From c3503fbbdef7917b3ff8a6f9ac8314379725650e Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Wed, 25 Feb 2026 10:03:32 -0300 Subject: [PATCH] Add vocabulary correction pipeline and example config --- README.md | 39 +++++- config.example.json | 50 +++++++ src/aiprocess.py | 23 +++- src/config.py | 175 +++++++++++++++++++++++- src/leld.py | 54 +++++++- src/vocabulary.py | 280 +++++++++++++++++++++++++++++++++++++++ tests/test_config.py | 84 ++++++++++++ tests/test_leld.py | 107 +++++++++++++-- tests/test_vocabulary.py | 76 +++++++++++ 9 files changed, 865 insertions(+), 23 deletions(-) create mode 100644 config.example.json create mode 100644 src/vocabulary.py create mode 100644 tests/test_vocabulary.py diff --git a/README.md b/README.md index cd1a15d..b2d4633 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # lel -Python X11 STT daemon that records audio, runs Whisper, and injects text. It can optionally run local AI post-processing before injection. +Python X11 STT daemon that records audio, runs Whisper, applies local AI cleanup, and injects text. ## Requirements @@ -92,21 +92,50 @@ Create `~/.config/lel/config.json`: "stt": { "model": "base", "device": "cpu" }, "injection": { "backend": "clipboard" }, "ai": { "enabled": true }, - "logging": { "log_transcript": false } + "logging": { "log_transcript": false }, + "vocabulary": { + "replacements": [ + { "from": "Martha", "to": "Marta" }, + { "from": "docker", "to": "Docker" } + ], + "terms": ["Systemd", "Kubernetes"], + "max_rules": 500, + "max_terms": 500 + }, + "domain_inference": { "enabled": true, "mode": "auto" } } ``` Recording input can be a device index (preferred) or a substring of the device name. -`ai.enabled` controls local cleanup. When enabled, the LLM model is downloaded -on first use to `~/.cache/lel/models/` and uses the locked Llama-3.2-3B GGUF -model. +`ai.enabled` is accepted for compatibility but currently has no runtime effect. +AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model +downloaded to `~/.cache/lel/models/` on first use. `logging.log_transcript` controls whether recognized/processed text is written to logs. This is disabled by default. `-v/--verbose` also enables transcript logging and llama.cpp logs; llama logs are prefixed with `llama::`. +Vocabulary correction: + +- `vocabulary.replacements` is deterministic correction (`from -> to`). +- `vocabulary.terms` is a preferred spelling list used as hinting context. +- Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules. +- Rules are deduplicated case-insensitively; conflicting replacements are rejected. +- Limits are bounded by `max_rules` and `max_terms`. + +Domain inference: + +- `domain_inference.mode` currently supports `auto`. +- Domain context is advisory only and is used to improve cleanup prompts. +- When confidence is low, it falls back to `general` context. + +STT hinting: + +- Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those + arguments are supported by the installed `faster-whisper` runtime. + ## systemd user service ```bash diff --git a/config.example.json b/config.example.json new file mode 100644 index 0000000..e254ebf --- /dev/null +++ b/config.example.json @@ -0,0 +1,50 @@ +{ + "daemon": { + "hotkey": "Cmd+m" + }, + "recording": { + "input": "" + }, + "stt": { + "model": "base", + "device": "cpu" + }, + "injection": { + "backend": "clipboard" + }, + "ai": { + "enabled": true + }, + "logging": { + "log_transcript": true + }, + "vocabulary": { + "replacements": [ + { + "from": "Martha", + "to": "Marta" + }, + { + "from": "docker", + "to": "Docker" + }, + { + "from": "system d", + "to": "systemd" + } + ], + "terms": [ + "Marta", + "Docker", + "systemd", + "Kubernetes", + "PostgreSQL" + ], + "max_rules": 500, + "max_terms": 500 + }, + "domain_inference": { + "enabled": true, + "mode": "auto" + } +} diff --git a/src/aiprocess.py b/src/aiprocess.py index b02badf..ec6ee91 100644 --- a/src/aiprocess.py +++ b/src/aiprocess.py @@ -24,6 +24,9 @@ SYSTEM_PROMPT = ( "- Remove filler words (um/uh/like)\n" "- Remove false starts\n" "- Remove self-corrections.\n" + "- If a section exists, apply only the listed corrections.\n" + "- Keep dictionary spellings exactly as provided.\n" + "- Treat domain hints as advisory only; never invent context-specific jargon.\n" "- Output ONLY the cleaned text, no commentary.\n\n" "Examples:\n" " - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n" @@ -49,9 +52,23 @@ class LlamaProcessor: verbose=verbose, ) - def process(self, text: str, lang: str = "en") -> str: - user_content = f"{text}" - user_content = f"{lang}\n{user_content}" + def process( + self, + text: str, + lang: str = "en", + *, + dictionary_context: str = "", + domain_name: str = "general", + domain_confidence: float = 0.0, + ) -> str: + blocks = [ + f"{lang}", + f'', + ] + if dictionary_context.strip(): + blocks.append(f"\n{dictionary_context.strip()}\n") + blocks.append(f"{text}") + user_content = "\n".join(blocks) response = self.client.create_chat_completion( messages=[ {"role": "system", "content": SYSTEM_PROMPT}, diff --git a/src/config.py b/src/config.py index a715aa2..5c828b7 100644 --- a/src/config.py +++ b/src/config.py @@ -12,7 +12,11 @@ DEFAULT_HOTKEY = "Cmd+m" DEFAULT_STT_MODEL = "base" DEFAULT_STT_DEVICE = "cpu" DEFAULT_INJECTION_BACKEND = "clipboard" +DEFAULT_VOCAB_LIMIT = 500 +DEFAULT_DOMAIN_INFERENCE_MODE = "auto" ALLOWED_INJECTION_BACKENDS = {"clipboard", "injection"} +ALLOWED_DOMAIN_INFERENCE_MODES = {"auto"} +WILDCARD_CHARS = set("*?[]{}") @dataclass @@ -46,6 +50,26 @@ class LoggingConfig: log_transcript: bool = False +@dataclass +class VocabularyReplacement: + source: str + target: str + + +@dataclass +class VocabularyConfig: + replacements: list[VocabularyReplacement] = field(default_factory=list) + terms: list[str] = field(default_factory=list) + max_rules: int = DEFAULT_VOCAB_LIMIT + max_terms: int = DEFAULT_VOCAB_LIMIT + + +@dataclass +class DomainInferenceConfig: + enabled: bool = True + mode: str = DEFAULT_DOMAIN_INFERENCE_MODE + + @dataclass class Config: daemon: DaemonConfig = field(default_factory=DaemonConfig) @@ -54,6 +78,8 @@ class Config: injection: InjectionConfig = field(default_factory=InjectionConfig) ai: AiConfig = field(default_factory=AiConfig) logging: LoggingConfig = field(default_factory=LoggingConfig) + vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig) + domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig) def load(path: str | None) -> Config: @@ -102,10 +128,43 @@ def validate(cfg: Config) -> None: if not isinstance(cfg.logging.log_transcript, bool): raise ValueError("logging.log_transcript must be boolean") + cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules") + cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms") + + if len(cfg.vocabulary.replacements) > cfg.vocabulary.max_rules: + raise ValueError( + f"vocabulary.replacements cannot exceed vocabulary.max_rules ({cfg.vocabulary.max_rules})" + ) + if len(cfg.vocabulary.terms) > cfg.vocabulary.max_terms: + raise ValueError( + f"vocabulary.terms cannot exceed vocabulary.max_terms ({cfg.vocabulary.max_terms})" + ) + + cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements) + cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms) + + if not isinstance(cfg.domain_inference.enabled, bool): + raise ValueError("domain_inference.enabled must be boolean") + mode = cfg.domain_inference.mode.strip().lower() + if mode not in ALLOWED_DOMAIN_INFERENCE_MODES: + allowed = ", ".join(sorted(ALLOWED_DOMAIN_INFERENCE_MODES)) + raise ValueError(f"domain_inference.mode must be one of: {allowed}") + cfg.domain_inference.mode = mode + def _from_dict(data: dict[str, Any], cfg: Config) -> Config: has_sections = any( - key in data for key in ("daemon", "recording", "stt", "injection", "ai", "logging") + key in data + for key in ( + "daemon", + "recording", + "stt", + "injection", + "ai", + "logging", + "vocabulary", + "domain_inference", + ) ) if has_sections: daemon = _ensure_dict(data.get("daemon"), "daemon") @@ -114,6 +173,8 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: injection = _ensure_dict(data.get("injection"), "injection") ai = _ensure_dict(data.get("ai"), "ai") logging_cfg = _ensure_dict(data.get("logging"), "logging") + vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary") + domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference") if "hotkey" in daemon: cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey") @@ -129,6 +190,22 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled") if "log_transcript" in logging_cfg: cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript") + if "replacements" in vocabulary: + cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"]) + if "terms" in vocabulary: + cfg.vocabulary.terms = _as_terms(vocabulary["terms"]) + if "max_rules" in vocabulary: + cfg.vocabulary.max_rules = _as_int(vocabulary["max_rules"], "vocabulary.max_rules") + if "max_terms" in vocabulary: + cfg.vocabulary.max_terms = _as_int(vocabulary["max_terms"], "vocabulary.max_terms") + if "enabled" in domain_inference: + cfg.domain_inference.enabled = _as_bool( + domain_inference["enabled"], "domain_inference.enabled" + ) + if "mode" in domain_inference: + cfg.domain_inference.mode = _as_nonempty_str( + domain_inference["mode"], "domain_inference.mode" + ) return cfg if "hotkey" in data: @@ -170,6 +247,12 @@ def _as_bool(value: Any, field_name: str) -> bool: return value +def _as_int(value: Any, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"{field_name} must be an integer") + return value + + def _as_recording_input(value: Any) -> str | int | None: if value is None: return None @@ -178,3 +261,93 @@ def _as_recording_input(value: Any) -> str | int | None: if isinstance(value, (str, int)): return value raise ValueError("recording.input must be string, integer, or null") + + +def _as_replacements(value: Any) -> list[VocabularyReplacement]: + if not isinstance(value, list): + raise ValueError("vocabulary.replacements must be a list") + replacements: list[VocabularyReplacement] = [] + for i, item in enumerate(value): + if not isinstance(item, dict): + raise ValueError(f"vocabulary.replacements[{i}] must be an object") + if "from" not in item: + raise ValueError(f"vocabulary.replacements[{i}].from is required") + if "to" not in item: + raise ValueError(f"vocabulary.replacements[{i}].to is required") + source = _as_nonempty_str(item["from"], f"vocabulary.replacements[{i}].from") + target = _as_nonempty_str(item["to"], f"vocabulary.replacements[{i}].to") + replacements.append(VocabularyReplacement(source=source, target=target)) + return replacements + + +def _as_terms(value: Any) -> list[str]: + if not isinstance(value, list): + raise ValueError("vocabulary.terms must be a list") + terms: list[str] = [] + for i, item in enumerate(value): + terms.append(_as_nonempty_str(item, f"vocabulary.terms[{i}]")) + return terms + + +def _validated_limit(value: int, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"{field_name} must be an integer") + if value <= 0: + raise ValueError(f"{field_name} must be positive") + if value > 5000: + raise ValueError(f"{field_name} cannot exceed 5000") + return value + + +def _validate_replacements(value: list[VocabularyReplacement]) -> list[VocabularyReplacement]: + deduped: list[VocabularyReplacement] = [] + seen: dict[str, str] = {} + for i, item in enumerate(value): + source = item.source.strip() + target = item.target.strip() + if not source: + raise ValueError(f"vocabulary.replacements[{i}].from cannot be empty") + if not target: + raise ValueError(f"vocabulary.replacements[{i}].to cannot be empty") + if source == target: + raise ValueError(f"vocabulary.replacements[{i}] cannot map a term to itself") + if "\n" in source or "\n" in target: + raise ValueError(f"vocabulary.replacements[{i}] cannot contain newlines") + if any(ch in source for ch in WILDCARD_CHARS): + raise ValueError( + f"vocabulary.replacements[{i}].from cannot contain wildcard characters" + ) + + source_key = _normalize_key(source) + target_key = _normalize_key(target) + prev_target = seen.get(source_key) + if prev_target is None: + seen[source_key] = target + deduped.append(VocabularyReplacement(source=source, target=target)) + continue + if _normalize_key(prev_target) != target_key: + raise ValueError(f"vocabulary.replacements has conflicting entries for '{source}'") + return deduped + + +def _validate_terms(value: list[str]) -> list[str]: + deduped: list[str] = [] + seen: set[str] = set() + for i, term in enumerate(value): + cleaned = term.strip() + if not cleaned: + raise ValueError(f"vocabulary.terms[{i}] cannot be empty") + if "\n" in cleaned: + raise ValueError(f"vocabulary.terms[{i}] cannot contain newlines") + if any(ch in cleaned for ch in WILDCARD_CHARS): + raise ValueError(f"vocabulary.terms[{i}] cannot contain wildcard characters") + key = _normalize_key(cleaned) + if key in seen: + continue + seen.add(key) + deduped.append(cleaned) + return deduped + + +def _normalize_key(value: str) -> str: + return " ".join(value.casefold().split()) diff --git a/src/leld.py b/src/leld.py index 7bc41db..cb9f278 100755 --- a/src/leld.py +++ b/src/leld.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import errno +import inspect import json import logging import os @@ -19,6 +20,7 @@ from constants import RECORD_TIMEOUT_SEC, STT_LANGUAGE from desktop import get_desktop_adapter from recorder import start_recording as start_audio_recording from recorder import stop_recording as stop_audio_recording +from vocabulary import VocabularyEngine class State: @@ -68,9 +70,10 @@ class Daemon: cfg.stt.model, cfg.stt.device, ) - self.ai_enabled = cfg.ai.enabled self.ai_processor: LlamaProcessor | None = None self.log_transcript = cfg.logging.log_transcript or verbose + self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference) + self._stt_hint_kwargs_cache: dict[str, Any] | None = None def set_state(self, state: str): with self.lock: @@ -190,18 +193,25 @@ class Daemon: else: logging.info("stt produced %d chars", len(text)) - if self.ai_enabled and not self._shutdown_requested.is_set(): + domain = self.vocabulary.infer_domain(text) + if not self._shutdown_requested.is_set(): self.set_state(State.PROCESSING) logging.info("ai processing started") try: processor = self._get_ai_processor() - ai_text = processor.process(text) + ai_text = processor.process( + text, + lang=STT_LANGUAGE, + dictionary_context=self.vocabulary.build_ai_dictionary_context(), + domain_name=domain.name, + domain_confidence=domain.confidence, + ) if ai_text and ai_text.strip(): text = ai_text.strip() except Exception as exc: logging.error("ai process failed: %s", exc) - else: - logging.info("ai processing disabled") + + text = self.vocabulary.apply_deterministic_replacements(text).strip() if self.log_transcript: logging.info("processed: %s", text) @@ -251,7 +261,12 @@ class Daemon: return self.get_state() == State.IDLE def _transcribe(self, audio) -> str: - segments, _info = self.model.transcribe(audio, language=STT_LANGUAGE, vad_filter=True) + kwargs: dict[str, Any] = { + "language": STT_LANGUAGE, + "vad_filter": True, + } + kwargs.update(self._stt_hint_kwargs()) + segments, _info = self.model.transcribe(audio, **kwargs) parts = [] for seg in segments: text = (seg.text or "").strip() @@ -264,6 +279,33 @@ class Daemon: self.ai_processor = LlamaProcessor(verbose=self.verbose) return self.ai_processor + def _stt_hint_kwargs(self) -> dict[str, Any]: + if self._stt_hint_kwargs_cache is not None: + return self._stt_hint_kwargs_cache + + hotwords, initial_prompt = self.vocabulary.build_stt_hints() + if not hotwords and not initial_prompt: + self._stt_hint_kwargs_cache = {} + return self._stt_hint_kwargs_cache + + try: + signature = inspect.signature(self.model.transcribe) + except (TypeError, ValueError): + logging.debug("stt signature inspection failed; skipping hints") + self._stt_hint_kwargs_cache = {} + return self._stt_hint_kwargs_cache + + params = signature.parameters + kwargs: dict[str, Any] = {} + if hotwords and "hotwords" in params: + kwargs["hotwords"] = hotwords + if initial_prompt and "initial_prompt" in params: + kwargs["initial_prompt"] = initial_prompt + if not kwargs: + logging.debug("stt hint arguments are not supported by this whisper runtime") + self._stt_hint_kwargs_cache = kwargs + return self._stt_hint_kwargs_cache + def _read_lock_pid(lock_file) -> str: lock_file.seek(0) diff --git a/src/vocabulary.py b/src/vocabulary.py new file mode 100644 index 0000000..1a7ebb2 --- /dev/null +++ b/src/vocabulary.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Iterable + +from config import DomainInferenceConfig, VocabularyConfig + + +DOMAIN_GENERAL = "general" +DOMAIN_PERSONAL_NAMES = "personal_names" +DOMAIN_SOFTWARE_DEV = "software_dev" +DOMAIN_OPS_INFRA = "ops_infra" +DOMAIN_BUSINESS = "business" +DOMAIN_MEDICAL_LEGAL = "medical_legal" + +DOMAIN_ORDER = ( + DOMAIN_PERSONAL_NAMES, + DOMAIN_SOFTWARE_DEV, + DOMAIN_OPS_INFRA, + DOMAIN_BUSINESS, + DOMAIN_MEDICAL_LEGAL, +) + +DOMAIN_KEYWORDS = { + DOMAIN_SOFTWARE_DEV: { + "api", + "bug", + "code", + "commit", + "docker", + "function", + "git", + "github", + "javascript", + "python", + "refactor", + "repository", + "typescript", + "unit", + "test", + }, + DOMAIN_OPS_INFRA: { + "cluster", + "container", + "deploy", + "deployment", + "incident", + "kubernetes", + "monitoring", + "nginx", + "pod", + "prod", + "service", + "systemd", + "terraform", + }, + DOMAIN_BUSINESS: { + "budget", + "client", + "deadline", + "finance", + "invoice", + "meeting", + "milestone", + "project", + "quarter", + "roadmap", + "sales", + "stakeholder", + }, + DOMAIN_MEDICAL_LEGAL: { + "agreement", + "case", + "claim", + "compliance", + "contract", + "diagnosis", + "liability", + "patient", + "prescription", + "regulation", + "symptom", + "treatment", + }, +} + +DOMAIN_PHRASES = { + DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"), + DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"), + DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"), + DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"), +} + +GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"} + + +@dataclass(frozen=True) +class DomainResult: + name: str + confidence: float + + +@dataclass(frozen=True) +class _ReplacementView: + source: str + target: str + + +class VocabularyEngine: + def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig): + self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements] + self._terms = list(vocab_cfg.terms) + self._domain_enabled = bool(domain_cfg.enabled) + + self._replacement_map = { + _normalize_key(rule.source): rule.target for rule in self._replacements + } + self._replacement_pattern = _build_replacement_pattern(rule.source for rule in self._replacements) + + # Keep hint payload bounded so model prompts do not balloon. + self._stt_hotwords = self._build_stt_hotwords(limit=128, char_budget=1024) + self._stt_initial_prompt = self._build_stt_initial_prompt(char_budget=600) + + def has_dictionary(self) -> bool: + return bool(self._replacements or self._terms) + + def apply_deterministic_replacements(self, text: str) -> str: + if not text or self._replacement_pattern is None: + return text + + def _replace(match: re.Match[str]) -> str: + source_text = match.group(0) + key = _normalize_key(source_text) + return self._replacement_map.get(key, source_text) + + return self._replacement_pattern.sub(_replace, text) + + def build_stt_hints(self) -> tuple[str, str]: + return self._stt_hotwords, self._stt_initial_prompt + + def build_ai_dictionary_context(self, max_lines: int = 80, char_budget: int = 1500) -> str: + lines: list[str] = [] + for rule in self._replacements: + lines.append(f"replace: {rule.source} -> {rule.target}") + for term in self._terms: + lines.append(f"prefer: {term}") + + if not lines: + return "" + + out: list[str] = [] + used = 0 + for line in lines: + if len(out) >= max_lines: + break + addition = len(line) + (1 if out else 0) + if used + addition > char_budget: + break + out.append(line) + used += addition + return "\n".join(out) + + def infer_domain(self, text: str) -> DomainResult: + if not self._domain_enabled: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + normalized = text.casefold() + tokens = re.findall(r"[a-z0-9+#./_-]+", normalized) + if not tokens: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + scores = {domain: 0 for domain in DOMAIN_ORDER} + for token in tokens: + for domain, keywords in DOMAIN_KEYWORDS.items(): + if token in keywords: + scores[domain] += 2 + + for domain, phrases in DOMAIN_PHRASES.items(): + for phrase in phrases: + if phrase in normalized: + scores[domain] += 2 + + if any(token in GREETING_TOKENS for token in tokens): + scores[DOMAIN_PERSONAL_NAMES] += 1 + + # Boost domains from configured dictionary terms and replacement targets. + dictionary_tokens = self._dictionary_tokens() + for token in dictionary_tokens: + for domain, keywords in DOMAIN_KEYWORDS.items(): + if token in keywords and token in tokens: + scores[domain] += 1 + + top_domain = DOMAIN_GENERAL + top_score = 0 + total_score = 0 + for domain in DOMAIN_ORDER: + score = scores[domain] + total_score += score + if score > top_score: + top_score = score + top_domain = domain + + if top_score < 2 or total_score == 0: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + confidence = top_score / total_score + if confidence < 0.45: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + return DomainResult(name=top_domain, confidence=round(confidence, 2)) + + def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str: + items = _dedupe_preserve_order( + [rule.target for rule in self._replacements] + self._terms + ) + words: list[str] = [] + used = 0 + for item in items: + if len(words) >= limit: + break + addition = len(item) + (2 if words else 0) + if used + addition > char_budget: + break + words.append(item) + used += addition + return ", ".join(words) + + def _build_stt_initial_prompt(self, *, char_budget: int) -> str: + if not self._stt_hotwords: + return "" + prefix = "Preferred vocabulary: " + available = max(char_budget - len(prefix), 0) + hotwords = self._stt_hotwords[:available].rstrip(", ") + if not hotwords: + return "" + return prefix + hotwords + + def _dictionary_tokens(self) -> set[str]: + values: list[str] = [] + for rule in self._replacements: + values.append(rule.source) + values.append(rule.target) + values.extend(self._terms) + + tokens: set[str] = set() + for value in values: + for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()): + tokens.add(token) + return tokens + + +def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None: + unique_sources = _dedupe_preserve_order(list(sources)) + if not unique_sources: + return None + + unique_sources.sort(key=lambda item: (-len(item), item.casefold())) + escaped = [re.escape(item) for item in unique_sources] + pattern = r"(? list[str]: + out: list[str] = [] + seen: set[str] = set() + for value in values: + cleaned = value.strip() + if not cleaned: + continue + key = _normalize_key(cleaned) + if key in seen: + continue + seen.add(key) + out.append(cleaned) + return out + + +def _normalize_key(value: str) -> str: + return " ".join(value.casefold().split()) diff --git a/tests/test_config.py b/tests/test_config.py index 26a6fd0..84a2768 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,6 +27,12 @@ class ConfigTests(unittest.TestCase): self.assertEqual(cfg.injection.backend, "clipboard") self.assertTrue(cfg.ai.enabled) self.assertFalse(cfg.logging.log_transcript) + self.assertEqual(cfg.vocabulary.replacements, []) + self.assertEqual(cfg.vocabulary.terms, []) + self.assertEqual(cfg.vocabulary.max_rules, 500) + self.assertEqual(cfg.vocabulary.max_terms, 500) + self.assertTrue(cfg.domain_inference.enabled) + self.assertEqual(cfg.domain_inference.mode, "auto") def test_loads_nested_config(self): payload = { @@ -36,6 +42,16 @@ class ConfigTests(unittest.TestCase): "injection": {"backend": "injection"}, "ai": {"enabled": False}, "logging": {"log_transcript": True}, + "vocabulary": { + "replacements": [ + {"from": "Martha", "to": "Marta"}, + {"from": "docker", "to": "Docker"}, + ], + "terms": ["Systemd", "Kubernetes"], + "max_rules": 100, + "max_terms": 200, + }, + "domain_inference": {"enabled": True, "mode": "auto"}, } with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" @@ -50,6 +66,14 @@ class ConfigTests(unittest.TestCase): self.assertEqual(cfg.injection.backend, "injection") self.assertFalse(cfg.ai.enabled) self.assertTrue(cfg.logging.log_transcript) + self.assertEqual(cfg.vocabulary.max_rules, 100) + self.assertEqual(cfg.vocabulary.max_terms, 200) + self.assertEqual(len(cfg.vocabulary.replacements), 2) + self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha") + self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta") + self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"]) + self.assertTrue(cfg.domain_inference.enabled) + self.assertEqual(cfg.domain_inference.mode, "auto") def test_loads_legacy_keys(self): payload = { @@ -74,6 +98,7 @@ class ConfigTests(unittest.TestCase): self.assertEqual(cfg.injection.backend, "clipboard") self.assertFalse(cfg.ai.enabled) self.assertTrue(cfg.logging.log_transcript) + self.assertEqual(cfg.vocabulary.replacements, []) def test_invalid_injection_backend_raises(self): payload = {"injection": {"backend": "invalid"}} @@ -93,6 +118,65 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "logging.log_transcript"): load(str(path)) + def test_conflicting_replacements_raise(self): + payload = { + "vocabulary": { + "replacements": [ + {"from": "Martha", "to": "Marta"}, + {"from": "martha", "to": "Martha"}, + ] + } + } + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "conflicting"): + load(str(path)) + + def test_duplicate_rules_and_terms_are_deduplicated(self): + payload = { + "vocabulary": { + "replacements": [ + {"from": "docker", "to": "Docker"}, + {"from": "DOCKER", "to": "Docker"}, + ], + "terms": ["Systemd", "systemd"], + } + } + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + cfg = load(str(path)) + + self.assertEqual(len(cfg.vocabulary.replacements), 1) + self.assertEqual(cfg.vocabulary.replacements[0].source, "docker") + self.assertEqual(cfg.vocabulary.replacements[0].target, "Docker") + self.assertEqual(cfg.vocabulary.terms, ["Systemd"]) + + def test_wildcard_term_raises(self): + payload = { + "vocabulary": { + "terms": ["Dock*"], + } + } + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "wildcard"): + load(str(path)) + + def test_invalid_domain_mode_raises(self): + payload = {"domain_inference": {"mode": "heuristic"}} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "domain_inference.mode"): + load(str(path)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_leld.py b/tests/test_leld.py index 4cd724c..6e97f96 100644 --- a/tests/test_leld.py +++ b/tests/test_leld.py @@ -11,7 +11,7 @@ if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) import leld -from config import Config +from config import Config, VocabularyReplacement class FakeDesktop: @@ -32,8 +32,43 @@ class FakeSegment: class FakeModel: + def __init__(self, text: str = "hello world"): + self.text = text + self.last_kwargs = {} + def transcribe(self, _audio, language=None, vad_filter=None): - return [FakeSegment("hello world")], {"language": language, "vad_filter": vad_filter} + self.last_kwargs = { + "language": language, + "vad_filter": vad_filter, + } + return [FakeSegment(self.text)], self.last_kwargs + + +class FakeHintModel: + def __init__(self, text: str = "hello world"): + self.text = text + self.last_kwargs = {} + + def transcribe( + self, + _audio, + language=None, + vad_filter=None, + hotwords=None, + initial_prompt=None, + ): + self.last_kwargs = { + "language": language, + "vad_filter": vad_filter, + "hotwords": hotwords, + "initial_prompt": initial_prompt, + } + return [FakeSegment(self.text)], self.last_kwargs + + +class FakeAIProcessor: + def process(self, text, lang="en", **_kwargs): + return text class FakeAudio: @@ -48,12 +83,13 @@ class DaemonTests(unittest.TestCase): cfg.logging.log_transcript = False return cfg - @patch("leld._build_whisper_model", return_value=FakeModel()) @patch("leld.stop_audio_recording", return_value=FakeAudio(8)) @patch("leld.start_audio_recording", return_value=(object(), object())) - def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock, _model_mock): + def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock): desktop = FakeDesktop() - daemon = leld.Daemon(self._config(), desktop, verbose=False) + with patch("leld._build_whisper_model", return_value=FakeModel()): + daemon = leld.Daemon(self._config(), desktop, verbose=False) + daemon.ai_processor = FakeAIProcessor() daemon._start_stop_worker = ( lambda stream, record, trigger, process_audio: daemon._stop_and_process( stream, record, trigger, process_audio @@ -68,12 +104,13 @@ class DaemonTests(unittest.TestCase): self.assertEqual(daemon.get_state(), leld.State.IDLE) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard")]) - @patch("leld._build_whisper_model", return_value=FakeModel()) @patch("leld.stop_audio_recording", return_value=FakeAudio(8)) @patch("leld.start_audio_recording", return_value=(object(), object())) - def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock, _model_mock): + def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock): desktop = FakeDesktop() - daemon = leld.Daemon(self._config(), desktop, verbose=False) + with patch("leld._build_whisper_model", return_value=FakeModel()): + daemon = leld.Daemon(self._config(), desktop, verbose=False) + daemon.ai_processor = FakeAIProcessor() daemon._start_stop_worker = ( lambda stream, record, trigger, process_audio: daemon._stop_and_process( stream, record, trigger, process_audio @@ -87,6 +124,60 @@ class DaemonTests(unittest.TestCase): self.assertEqual(daemon.get_state(), leld.State.IDLE) self.assertEqual(desktop.inject_calls, []) + @patch("leld.stop_audio_recording", return_value=FakeAudio(8)) + @patch("leld.start_audio_recording", return_value=(object(), object())) + def test_dictionary_replacement_applies_after_ai(self, _start_mock, _stop_mock): + desktop = FakeDesktop() + model = FakeModel(text="good morning martha") + cfg = self._config() + cfg.vocabulary.replacements = [VocabularyReplacement(source="Martha", target="Marta")] + + with patch("leld._build_whisper_model", return_value=model): + daemon = leld.Daemon(cfg, desktop, verbose=False) + daemon.ai_processor = FakeAIProcessor() + daemon._start_stop_worker = ( + lambda stream, record, trigger, process_audio: daemon._stop_and_process( + stream, record, trigger, process_audio + ) + ) + + daemon.toggle() + daemon.toggle() + + self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard")]) + + def test_transcribe_skips_hints_when_model_does_not_support_them(self): + desktop = FakeDesktop() + model = FakeModel(text="hello") + cfg = self._config() + cfg.vocabulary.terms = ["Docker", "Systemd"] + + with patch("leld._build_whisper_model", return_value=model): + daemon = leld.Daemon(cfg, desktop, verbose=False) + + result = daemon._transcribe(object()) + + self.assertEqual(result, "hello") + self.assertNotIn("hotwords", model.last_kwargs) + self.assertNotIn("initial_prompt", model.last_kwargs) + + def test_transcribe_applies_hints_when_model_supports_them(self): + desktop = FakeDesktop() + model = FakeHintModel(text="hello") + cfg = self._config() + cfg.vocabulary.terms = ["Systemd"] + cfg.vocabulary.replacements = [VocabularyReplacement(source="docker", target="Docker")] + + with patch("leld._build_whisper_model", return_value=model): + daemon = leld.Daemon(cfg, desktop, verbose=False) + + result = daemon._transcribe(object()) + + self.assertEqual(result, "hello") + self.assertIn("Docker", model.last_kwargs["hotwords"]) + self.assertIn("Systemd", model.last_kwargs["hotwords"]) + self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"]) + class LockTests(unittest.TestCase): def test_lock_rejects_second_instance(self): diff --git a/tests/test_vocabulary.py b/tests/test_vocabulary.py new file mode 100644 index 0000000..a941751 --- /dev/null +++ b/tests/test_vocabulary.py @@ -0,0 +1,76 @@ +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement +from vocabulary import DOMAIN_GENERAL, VocabularyEngine + + +class VocabularyEngineTests(unittest.TestCase): + def _engine(self, replacements=None, terms=None, domain_enabled=True): + vocab = VocabularyConfig( + replacements=replacements or [], + terms=terms or [], + ) + domain = DomainInferenceConfig(enabled=domain_enabled, mode="auto") + return VocabularyEngine(vocab, domain) + + def test_boundary_aware_replacement(self): + engine = self._engine( + replacements=[VocabularyReplacement(source="Martha", target="Marta")], + ) + + text = "Martha met Marthaville and Martha." + out = engine.apply_deterministic_replacements(text) + + self.assertEqual(out, "Marta met Marthaville and Marta.") + + def test_longest_match_replacement_wins(self): + engine = self._engine( + replacements=[ + VocabularyReplacement(source="new york", target="NYC"), + VocabularyReplacement(source="york", target="Yorkshire"), + ], + ) + + out = engine.apply_deterministic_replacements("new york york") + self.assertEqual(out, "NYC Yorkshire") + + def test_stt_hints_are_bounded(self): + terms = [f"term{i}" for i in range(300)] + engine = self._engine(terms=terms) + + hotwords, prompt = engine.build_stt_hints() + + self.assertLessEqual(len(hotwords), 1024) + self.assertLessEqual(len(prompt), 600) + + def test_domain_inference_general_fallback(self): + engine = self._engine() + result = engine.infer_domain("please call me later") + + self.assertEqual(result.name, DOMAIN_GENERAL) + self.assertEqual(result.confidence, 0.0) + + def test_domain_inference_for_technical_text(self): + engine = self._engine(terms=["Docker", "Systemd"]) + result = engine.infer_domain("restart Docker and systemd service on prod") + + self.assertNotEqual(result.name, DOMAIN_GENERAL) + self.assertGreater(result.confidence, 0.0) + + def test_domain_inference_can_be_disabled(self): + engine = self._engine(domain_enabled=False) + result = engine.infer_domain("please restart docker") + + self.assertEqual(result.name, DOMAIN_GENERAL) + self.assertEqual(result.confidence, 0.0) + + +if __name__ == "__main__": + unittest.main()