Add vocabulary correction pipeline and example config
This commit is contained in:
parent
f9224621fa
commit
c3503fbbde
9 changed files with 865 additions and 23 deletions
39
README.md
39
README.md
|
|
@ -1,6 +1,6 @@
|
||||||
# lel
|
# lel
|
||||||
|
|
||||||
Python X11 STT daemon that records audio, runs Whisper, and injects text. It can optionally run local AI post-processing before injection.
|
Python X11 STT daemon that records audio, runs Whisper, applies local AI cleanup, and injects text.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
|
|
@ -92,21 +92,50 @@ Create `~/.config/lel/config.json`:
|
||||||
"stt": { "model": "base", "device": "cpu" },
|
"stt": { "model": "base", "device": "cpu" },
|
||||||
"injection": { "backend": "clipboard" },
|
"injection": { "backend": "clipboard" },
|
||||||
"ai": { "enabled": true },
|
"ai": { "enabled": true },
|
||||||
"logging": { "log_transcript": false }
|
"logging": { "log_transcript": false },
|
||||||
|
"vocabulary": {
|
||||||
|
"replacements": [
|
||||||
|
{ "from": "Martha", "to": "Marta" },
|
||||||
|
{ "from": "docker", "to": "Docker" }
|
||||||
|
],
|
||||||
|
"terms": ["Systemd", "Kubernetes"],
|
||||||
|
"max_rules": 500,
|
||||||
|
"max_terms": 500
|
||||||
|
},
|
||||||
|
"domain_inference": { "enabled": true, "mode": "auto" }
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Recording input can be a device index (preferred) or a substring of the device
|
Recording input can be a device index (preferred) or a substring of the device
|
||||||
name.
|
name.
|
||||||
|
|
||||||
`ai.enabled` controls local cleanup. When enabled, the LLM model is downloaded
|
`ai.enabled` is accepted for compatibility but currently has no runtime effect.
|
||||||
on first use to `~/.cache/lel/models/` and uses the locked Llama-3.2-3B GGUF
|
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
||||||
model.
|
downloaded to `~/.cache/lel/models/` on first use.
|
||||||
|
|
||||||
`logging.log_transcript` controls whether recognized/processed text is written
|
`logging.log_transcript` controls whether recognized/processed text is written
|
||||||
to logs. This is disabled by default. `-v/--verbose` also enables transcript
|
to logs. This is disabled by default. `-v/--verbose` also enables transcript
|
||||||
logging and llama.cpp logs; llama logs are prefixed with `llama::`.
|
logging and llama.cpp logs; llama logs are prefixed with `llama::`.
|
||||||
|
|
||||||
|
Vocabulary correction:
|
||||||
|
|
||||||
|
- `vocabulary.replacements` is deterministic correction (`from -> to`).
|
||||||
|
- `vocabulary.terms` is a preferred spelling list used as hinting context.
|
||||||
|
- Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules.
|
||||||
|
- Rules are deduplicated case-insensitively; conflicting replacements are rejected.
|
||||||
|
- Limits are bounded by `max_rules` and `max_terms`.
|
||||||
|
|
||||||
|
Domain inference:
|
||||||
|
|
||||||
|
- `domain_inference.mode` currently supports `auto`.
|
||||||
|
- Domain context is advisory only and is used to improve cleanup prompts.
|
||||||
|
- When confidence is low, it falls back to `general` context.
|
||||||
|
|
||||||
|
STT hinting:
|
||||||
|
|
||||||
|
- Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those
|
||||||
|
arguments are supported by the installed `faster-whisper` runtime.
|
||||||
|
|
||||||
## systemd user service
|
## systemd user service
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
50
config.example.json
Normal file
50
config.example.json
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
{
|
||||||
|
"daemon": {
|
||||||
|
"hotkey": "Cmd+m"
|
||||||
|
},
|
||||||
|
"recording": {
|
||||||
|
"input": ""
|
||||||
|
},
|
||||||
|
"stt": {
|
||||||
|
"model": "base",
|
||||||
|
"device": "cpu"
|
||||||
|
},
|
||||||
|
"injection": {
|
||||||
|
"backend": "clipboard"
|
||||||
|
},
|
||||||
|
"ai": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"log_transcript": true
|
||||||
|
},
|
||||||
|
"vocabulary": {
|
||||||
|
"replacements": [
|
||||||
|
{
|
||||||
|
"from": "Martha",
|
||||||
|
"to": "Marta"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "docker",
|
||||||
|
"to": "Docker"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "system d",
|
||||||
|
"to": "systemd"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"terms": [
|
||||||
|
"Marta",
|
||||||
|
"Docker",
|
||||||
|
"systemd",
|
||||||
|
"Kubernetes",
|
||||||
|
"PostgreSQL"
|
||||||
|
],
|
||||||
|
"max_rules": 500,
|
||||||
|
"max_terms": 500
|
||||||
|
},
|
||||||
|
"domain_inference": {
|
||||||
|
"enabled": true,
|
||||||
|
"mode": "auto"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -24,6 +24,9 @@ SYSTEM_PROMPT = (
|
||||||
"- Remove filler words (um/uh/like)\n"
|
"- Remove filler words (um/uh/like)\n"
|
||||||
"- Remove false starts\n"
|
"- Remove false starts\n"
|
||||||
"- Remove self-corrections.\n"
|
"- Remove self-corrections.\n"
|
||||||
|
"- If a <dictionary> section exists, apply only the listed corrections.\n"
|
||||||
|
"- Keep dictionary spellings exactly as provided.\n"
|
||||||
|
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
|
||||||
"- Output ONLY the cleaned text, no commentary.\n\n"
|
"- Output ONLY the cleaned text, no commentary.\n\n"
|
||||||
"Examples:\n"
|
"Examples:\n"
|
||||||
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
|
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
|
||||||
|
|
@ -49,9 +52,23 @@ class LlamaProcessor:
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, text: str, lang: str = "en") -> str:
|
def process(
|
||||||
user_content = f"<transcript>{text}</transcript>"
|
self,
|
||||||
user_content = f"<language>{lang}</language>\n{user_content}"
|
text: str,
|
||||||
|
lang: str = "en",
|
||||||
|
*,
|
||||||
|
dictionary_context: str = "",
|
||||||
|
domain_name: str = "general",
|
||||||
|
domain_confidence: float = 0.0,
|
||||||
|
) -> str:
|
||||||
|
blocks = [
|
||||||
|
f"<language>{lang}</language>",
|
||||||
|
f'<domain name="{domain_name}" confidence="{domain_confidence:.2f}"/>',
|
||||||
|
]
|
||||||
|
if dictionary_context.strip():
|
||||||
|
blocks.append(f"<dictionary>\n{dictionary_context.strip()}\n</dictionary>")
|
||||||
|
blocks.append(f"<transcript>{text}</transcript>")
|
||||||
|
user_content = "\n".join(blocks)
|
||||||
response = self.client.create_chat_completion(
|
response = self.client.create_chat_completion(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": SYSTEM_PROMPT},
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
|
|
||||||
175
src/config.py
175
src/config.py
|
|
@ -12,7 +12,11 @@ DEFAULT_HOTKEY = "Cmd+m"
|
||||||
DEFAULT_STT_MODEL = "base"
|
DEFAULT_STT_MODEL = "base"
|
||||||
DEFAULT_STT_DEVICE = "cpu"
|
DEFAULT_STT_DEVICE = "cpu"
|
||||||
DEFAULT_INJECTION_BACKEND = "clipboard"
|
DEFAULT_INJECTION_BACKEND = "clipboard"
|
||||||
|
DEFAULT_VOCAB_LIMIT = 500
|
||||||
|
DEFAULT_DOMAIN_INFERENCE_MODE = "auto"
|
||||||
ALLOWED_INJECTION_BACKENDS = {"clipboard", "injection"}
|
ALLOWED_INJECTION_BACKENDS = {"clipboard", "injection"}
|
||||||
|
ALLOWED_DOMAIN_INFERENCE_MODES = {"auto"}
|
||||||
|
WILDCARD_CHARS = set("*?[]{}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -46,6 +50,26 @@ class LoggingConfig:
|
||||||
log_transcript: bool = False
|
log_transcript: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VocabularyReplacement:
|
||||||
|
source: str
|
||||||
|
target: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VocabularyConfig:
|
||||||
|
replacements: list[VocabularyReplacement] = field(default_factory=list)
|
||||||
|
terms: list[str] = field(default_factory=list)
|
||||||
|
max_rules: int = DEFAULT_VOCAB_LIMIT
|
||||||
|
max_terms: int = DEFAULT_VOCAB_LIMIT
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DomainInferenceConfig:
|
||||||
|
enabled: bool = True
|
||||||
|
mode: str = DEFAULT_DOMAIN_INFERENCE_MODE
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
daemon: DaemonConfig = field(default_factory=DaemonConfig)
|
daemon: DaemonConfig = field(default_factory=DaemonConfig)
|
||||||
|
|
@ -54,6 +78,8 @@ class Config:
|
||||||
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
||||||
ai: AiConfig = field(default_factory=AiConfig)
|
ai: AiConfig = field(default_factory=AiConfig)
|
||||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||||
|
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||||
|
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
|
||||||
|
|
||||||
|
|
||||||
def load(path: str | None) -> Config:
|
def load(path: str | None) -> Config:
|
||||||
|
|
@ -102,10 +128,43 @@ def validate(cfg: Config) -> None:
|
||||||
if not isinstance(cfg.logging.log_transcript, bool):
|
if not isinstance(cfg.logging.log_transcript, bool):
|
||||||
raise ValueError("logging.log_transcript must be boolean")
|
raise ValueError("logging.log_transcript must be boolean")
|
||||||
|
|
||||||
|
cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules")
|
||||||
|
cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms")
|
||||||
|
|
||||||
|
if len(cfg.vocabulary.replacements) > cfg.vocabulary.max_rules:
|
||||||
|
raise ValueError(
|
||||||
|
f"vocabulary.replacements cannot exceed vocabulary.max_rules ({cfg.vocabulary.max_rules})"
|
||||||
|
)
|
||||||
|
if len(cfg.vocabulary.terms) > cfg.vocabulary.max_terms:
|
||||||
|
raise ValueError(
|
||||||
|
f"vocabulary.terms cannot exceed vocabulary.max_terms ({cfg.vocabulary.max_terms})"
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
|
||||||
|
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
|
||||||
|
|
||||||
|
if not isinstance(cfg.domain_inference.enabled, bool):
|
||||||
|
raise ValueError("domain_inference.enabled must be boolean")
|
||||||
|
mode = cfg.domain_inference.mode.strip().lower()
|
||||||
|
if mode not in ALLOWED_DOMAIN_INFERENCE_MODES:
|
||||||
|
allowed = ", ".join(sorted(ALLOWED_DOMAIN_INFERENCE_MODES))
|
||||||
|
raise ValueError(f"domain_inference.mode must be one of: {allowed}")
|
||||||
|
cfg.domain_inference.mode = mode
|
||||||
|
|
||||||
|
|
||||||
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
has_sections = any(
|
has_sections = any(
|
||||||
key in data for key in ("daemon", "recording", "stt", "injection", "ai", "logging")
|
key in data
|
||||||
|
for key in (
|
||||||
|
"daemon",
|
||||||
|
"recording",
|
||||||
|
"stt",
|
||||||
|
"injection",
|
||||||
|
"ai",
|
||||||
|
"logging",
|
||||||
|
"vocabulary",
|
||||||
|
"domain_inference",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if has_sections:
|
if has_sections:
|
||||||
daemon = _ensure_dict(data.get("daemon"), "daemon")
|
daemon = _ensure_dict(data.get("daemon"), "daemon")
|
||||||
|
|
@ -114,6 +173,8 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
injection = _ensure_dict(data.get("injection"), "injection")
|
injection = _ensure_dict(data.get("injection"), "injection")
|
||||||
ai = _ensure_dict(data.get("ai"), "ai")
|
ai = _ensure_dict(data.get("ai"), "ai")
|
||||||
logging_cfg = _ensure_dict(data.get("logging"), "logging")
|
logging_cfg = _ensure_dict(data.get("logging"), "logging")
|
||||||
|
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
||||||
|
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
|
||||||
|
|
||||||
if "hotkey" in daemon:
|
if "hotkey" in daemon:
|
||||||
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
|
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
|
||||||
|
|
@ -129,6 +190,22 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
|
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
|
||||||
if "log_transcript" in logging_cfg:
|
if "log_transcript" in logging_cfg:
|
||||||
cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript")
|
cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript")
|
||||||
|
if "replacements" in vocabulary:
|
||||||
|
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
||||||
|
if "terms" in vocabulary:
|
||||||
|
cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
|
||||||
|
if "max_rules" in vocabulary:
|
||||||
|
cfg.vocabulary.max_rules = _as_int(vocabulary["max_rules"], "vocabulary.max_rules")
|
||||||
|
if "max_terms" in vocabulary:
|
||||||
|
cfg.vocabulary.max_terms = _as_int(vocabulary["max_terms"], "vocabulary.max_terms")
|
||||||
|
if "enabled" in domain_inference:
|
||||||
|
cfg.domain_inference.enabled = _as_bool(
|
||||||
|
domain_inference["enabled"], "domain_inference.enabled"
|
||||||
|
)
|
||||||
|
if "mode" in domain_inference:
|
||||||
|
cfg.domain_inference.mode = _as_nonempty_str(
|
||||||
|
domain_inference["mode"], "domain_inference.mode"
|
||||||
|
)
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
if "hotkey" in data:
|
if "hotkey" in data:
|
||||||
|
|
@ -170,6 +247,12 @@ def _as_bool(value: Any, field_name: str) -> bool:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _as_int(value: Any, field_name: str) -> int:
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int):
|
||||||
|
raise ValueError(f"{field_name} must be an integer")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _as_recording_input(value: Any) -> str | int | None:
|
def _as_recording_input(value: Any) -> str | int | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -178,3 +261,93 @@ def _as_recording_input(value: Any) -> str | int | None:
|
||||||
if isinstance(value, (str, int)):
|
if isinstance(value, (str, int)):
|
||||||
return value
|
return value
|
||||||
raise ValueError("recording.input must be string, integer, or null")
|
raise ValueError("recording.input must be string, integer, or null")
|
||||||
|
|
||||||
|
|
||||||
|
def _as_replacements(value: Any) -> list[VocabularyReplacement]:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError("vocabulary.replacements must be a list")
|
||||||
|
replacements: list[VocabularyReplacement] = []
|
||||||
|
for i, item in enumerate(value):
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}] must be an object")
|
||||||
|
if "from" not in item:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}].from is required")
|
||||||
|
if "to" not in item:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}].to is required")
|
||||||
|
source = _as_nonempty_str(item["from"], f"vocabulary.replacements[{i}].from")
|
||||||
|
target = _as_nonempty_str(item["to"], f"vocabulary.replacements[{i}].to")
|
||||||
|
replacements.append(VocabularyReplacement(source=source, target=target))
|
||||||
|
return replacements
|
||||||
|
|
||||||
|
|
||||||
|
def _as_terms(value: Any) -> list[str]:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError("vocabulary.terms must be a list")
|
||||||
|
terms: list[str] = []
|
||||||
|
for i, item in enumerate(value):
|
||||||
|
terms.append(_as_nonempty_str(item, f"vocabulary.terms[{i}]"))
|
||||||
|
return terms
|
||||||
|
|
||||||
|
|
||||||
|
def _validated_limit(value: int, field_name: str) -> int:
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int):
|
||||||
|
raise ValueError(f"{field_name} must be an integer")
|
||||||
|
if value <= 0:
|
||||||
|
raise ValueError(f"{field_name} must be positive")
|
||||||
|
if value > 5000:
|
||||||
|
raise ValueError(f"{field_name} cannot exceed 5000")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_replacements(value: list[VocabularyReplacement]) -> list[VocabularyReplacement]:
|
||||||
|
deduped: list[VocabularyReplacement] = []
|
||||||
|
seen: dict[str, str] = {}
|
||||||
|
for i, item in enumerate(value):
|
||||||
|
source = item.source.strip()
|
||||||
|
target = item.target.strip()
|
||||||
|
if not source:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}].from cannot be empty")
|
||||||
|
if not target:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}].to cannot be empty")
|
||||||
|
if source == target:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}] cannot map a term to itself")
|
||||||
|
if "\n" in source or "\n" in target:
|
||||||
|
raise ValueError(f"vocabulary.replacements[{i}] cannot contain newlines")
|
||||||
|
if any(ch in source for ch in WILDCARD_CHARS):
|
||||||
|
raise ValueError(
|
||||||
|
f"vocabulary.replacements[{i}].from cannot contain wildcard characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
source_key = _normalize_key(source)
|
||||||
|
target_key = _normalize_key(target)
|
||||||
|
prev_target = seen.get(source_key)
|
||||||
|
if prev_target is None:
|
||||||
|
seen[source_key] = target
|
||||||
|
deduped.append(VocabularyReplacement(source=source, target=target))
|
||||||
|
continue
|
||||||
|
if _normalize_key(prev_target) != target_key:
|
||||||
|
raise ValueError(f"vocabulary.replacements has conflicting entries for '{source}'")
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_terms(value: list[str]) -> list[str]:
|
||||||
|
deduped: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for i, term in enumerate(value):
|
||||||
|
cleaned = term.strip()
|
||||||
|
if not cleaned:
|
||||||
|
raise ValueError(f"vocabulary.terms[{i}] cannot be empty")
|
||||||
|
if "\n" in cleaned:
|
||||||
|
raise ValueError(f"vocabulary.terms[{i}] cannot contain newlines")
|
||||||
|
if any(ch in cleaned for ch in WILDCARD_CHARS):
|
||||||
|
raise ValueError(f"vocabulary.terms[{i}] cannot contain wildcard characters")
|
||||||
|
key = _normalize_key(cleaned)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
deduped.append(cleaned)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_key(value: str) -> str:
|
||||||
|
return " ".join(value.casefold().split())
|
||||||
|
|
|
||||||
54
src/leld.py
54
src/leld.py
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import errno
|
import errno
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -19,6 +20,7 @@ from constants import RECORD_TIMEOUT_SEC, STT_LANGUAGE
|
||||||
from desktop import get_desktop_adapter
|
from desktop import get_desktop_adapter
|
||||||
from recorder import start_recording as start_audio_recording
|
from recorder import start_recording as start_audio_recording
|
||||||
from recorder import stop_recording as stop_audio_recording
|
from recorder import stop_recording as stop_audio_recording
|
||||||
|
from vocabulary import VocabularyEngine
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
|
|
@ -68,9 +70,10 @@ class Daemon:
|
||||||
cfg.stt.model,
|
cfg.stt.model,
|
||||||
cfg.stt.device,
|
cfg.stt.device,
|
||||||
)
|
)
|
||||||
self.ai_enabled = cfg.ai.enabled
|
|
||||||
self.ai_processor: LlamaProcessor | None = None
|
self.ai_processor: LlamaProcessor | None = None
|
||||||
self.log_transcript = cfg.logging.log_transcript or verbose
|
self.log_transcript = cfg.logging.log_transcript or verbose
|
||||||
|
self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
|
||||||
|
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
||||||
|
|
||||||
def set_state(self, state: str):
|
def set_state(self, state: str):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
|
@ -190,18 +193,25 @@ class Daemon:
|
||||||
else:
|
else:
|
||||||
logging.info("stt produced %d chars", len(text))
|
logging.info("stt produced %d chars", len(text))
|
||||||
|
|
||||||
if self.ai_enabled and not self._shutdown_requested.is_set():
|
domain = self.vocabulary.infer_domain(text)
|
||||||
|
if not self._shutdown_requested.is_set():
|
||||||
self.set_state(State.PROCESSING)
|
self.set_state(State.PROCESSING)
|
||||||
logging.info("ai processing started")
|
logging.info("ai processing started")
|
||||||
try:
|
try:
|
||||||
processor = self._get_ai_processor()
|
processor = self._get_ai_processor()
|
||||||
ai_text = processor.process(text)
|
ai_text = processor.process(
|
||||||
|
text,
|
||||||
|
lang=STT_LANGUAGE,
|
||||||
|
dictionary_context=self.vocabulary.build_ai_dictionary_context(),
|
||||||
|
domain_name=domain.name,
|
||||||
|
domain_confidence=domain.confidence,
|
||||||
|
)
|
||||||
if ai_text and ai_text.strip():
|
if ai_text and ai_text.strip():
|
||||||
text = ai_text.strip()
|
text = ai_text.strip()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error("ai process failed: %s", exc)
|
logging.error("ai process failed: %s", exc)
|
||||||
else:
|
|
||||||
logging.info("ai processing disabled")
|
text = self.vocabulary.apply_deterministic_replacements(text).strip()
|
||||||
|
|
||||||
if self.log_transcript:
|
if self.log_transcript:
|
||||||
logging.info("processed: %s", text)
|
logging.info("processed: %s", text)
|
||||||
|
|
@ -251,7 +261,12 @@ class Daemon:
|
||||||
return self.get_state() == State.IDLE
|
return self.get_state() == State.IDLE
|
||||||
|
|
||||||
def _transcribe(self, audio) -> str:
|
def _transcribe(self, audio) -> str:
|
||||||
segments, _info = self.model.transcribe(audio, language=STT_LANGUAGE, vad_filter=True)
|
kwargs: dict[str, Any] = {
|
||||||
|
"language": STT_LANGUAGE,
|
||||||
|
"vad_filter": True,
|
||||||
|
}
|
||||||
|
kwargs.update(self._stt_hint_kwargs())
|
||||||
|
segments, _info = self.model.transcribe(audio, **kwargs)
|
||||||
parts = []
|
parts = []
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
text = (seg.text or "").strip()
|
text = (seg.text or "").strip()
|
||||||
|
|
@ -264,6 +279,33 @@ class Daemon:
|
||||||
self.ai_processor = LlamaProcessor(verbose=self.verbose)
|
self.ai_processor = LlamaProcessor(verbose=self.verbose)
|
||||||
return self.ai_processor
|
return self.ai_processor
|
||||||
|
|
||||||
|
def _stt_hint_kwargs(self) -> dict[str, Any]:
|
||||||
|
if self._stt_hint_kwargs_cache is not None:
|
||||||
|
return self._stt_hint_kwargs_cache
|
||||||
|
|
||||||
|
hotwords, initial_prompt = self.vocabulary.build_stt_hints()
|
||||||
|
if not hotwords and not initial_prompt:
|
||||||
|
self._stt_hint_kwargs_cache = {}
|
||||||
|
return self._stt_hint_kwargs_cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(self.model.transcribe)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logging.debug("stt signature inspection failed; skipping hints")
|
||||||
|
self._stt_hint_kwargs_cache = {}
|
||||||
|
return self._stt_hint_kwargs_cache
|
||||||
|
|
||||||
|
params = signature.parameters
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if hotwords and "hotwords" in params:
|
||||||
|
kwargs["hotwords"] = hotwords
|
||||||
|
if initial_prompt and "initial_prompt" in params:
|
||||||
|
kwargs["initial_prompt"] = initial_prompt
|
||||||
|
if not kwargs:
|
||||||
|
logging.debug("stt hint arguments are not supported by this whisper runtime")
|
||||||
|
self._stt_hint_kwargs_cache = kwargs
|
||||||
|
return self._stt_hint_kwargs_cache
|
||||||
|
|
||||||
|
|
||||||
def _read_lock_pid(lock_file) -> str:
|
def _read_lock_pid(lock_file) -> str:
|
||||||
lock_file.seek(0)
|
lock_file.seek(0)
|
||||||
|
|
|
||||||
280
src/vocabulary.py
Normal file
280
src/vocabulary.py
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from config import DomainInferenceConfig, VocabularyConfig
|
||||||
|
|
||||||
|
|
||||||
|
DOMAIN_GENERAL = "general"
|
||||||
|
DOMAIN_PERSONAL_NAMES = "personal_names"
|
||||||
|
DOMAIN_SOFTWARE_DEV = "software_dev"
|
||||||
|
DOMAIN_OPS_INFRA = "ops_infra"
|
||||||
|
DOMAIN_BUSINESS = "business"
|
||||||
|
DOMAIN_MEDICAL_LEGAL = "medical_legal"
|
||||||
|
|
||||||
|
DOMAIN_ORDER = (
|
||||||
|
DOMAIN_PERSONAL_NAMES,
|
||||||
|
DOMAIN_SOFTWARE_DEV,
|
||||||
|
DOMAIN_OPS_INFRA,
|
||||||
|
DOMAIN_BUSINESS,
|
||||||
|
DOMAIN_MEDICAL_LEGAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
DOMAIN_KEYWORDS = {
|
||||||
|
DOMAIN_SOFTWARE_DEV: {
|
||||||
|
"api",
|
||||||
|
"bug",
|
||||||
|
"code",
|
||||||
|
"commit",
|
||||||
|
"docker",
|
||||||
|
"function",
|
||||||
|
"git",
|
||||||
|
"github",
|
||||||
|
"javascript",
|
||||||
|
"python",
|
||||||
|
"refactor",
|
||||||
|
"repository",
|
||||||
|
"typescript",
|
||||||
|
"unit",
|
||||||
|
"test",
|
||||||
|
},
|
||||||
|
DOMAIN_OPS_INFRA: {
|
||||||
|
"cluster",
|
||||||
|
"container",
|
||||||
|
"deploy",
|
||||||
|
"deployment",
|
||||||
|
"incident",
|
||||||
|
"kubernetes",
|
||||||
|
"monitoring",
|
||||||
|
"nginx",
|
||||||
|
"pod",
|
||||||
|
"prod",
|
||||||
|
"service",
|
||||||
|
"systemd",
|
||||||
|
"terraform",
|
||||||
|
},
|
||||||
|
DOMAIN_BUSINESS: {
|
||||||
|
"budget",
|
||||||
|
"client",
|
||||||
|
"deadline",
|
||||||
|
"finance",
|
||||||
|
"invoice",
|
||||||
|
"meeting",
|
||||||
|
"milestone",
|
||||||
|
"project",
|
||||||
|
"quarter",
|
||||||
|
"roadmap",
|
||||||
|
"sales",
|
||||||
|
"stakeholder",
|
||||||
|
},
|
||||||
|
DOMAIN_MEDICAL_LEGAL: {
|
||||||
|
"agreement",
|
||||||
|
"case",
|
||||||
|
"claim",
|
||||||
|
"compliance",
|
||||||
|
"contract",
|
||||||
|
"diagnosis",
|
||||||
|
"liability",
|
||||||
|
"patient",
|
||||||
|
"prescription",
|
||||||
|
"regulation",
|
||||||
|
"symptom",
|
||||||
|
"treatment",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DOMAIN_PHRASES = {
|
||||||
|
DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"),
|
||||||
|
DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"),
|
||||||
|
DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"),
|
||||||
|
DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"),
|
||||||
|
}
|
||||||
|
|
||||||
|
GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DomainResult:
|
||||||
|
name: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _ReplacementView:
|
||||||
|
source: str
|
||||||
|
target: str
|
||||||
|
|
||||||
|
|
||||||
|
class VocabularyEngine:
|
||||||
|
def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig):
|
||||||
|
self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
|
||||||
|
self._terms = list(vocab_cfg.terms)
|
||||||
|
self._domain_enabled = bool(domain_cfg.enabled)
|
||||||
|
|
||||||
|
self._replacement_map = {
|
||||||
|
_normalize_key(rule.source): rule.target for rule in self._replacements
|
||||||
|
}
|
||||||
|
self._replacement_pattern = _build_replacement_pattern(rule.source for rule in self._replacements)
|
||||||
|
|
||||||
|
# Keep hint payload bounded so model prompts do not balloon.
|
||||||
|
self._stt_hotwords = self._build_stt_hotwords(limit=128, char_budget=1024)
|
||||||
|
self._stt_initial_prompt = self._build_stt_initial_prompt(char_budget=600)
|
||||||
|
|
||||||
|
def has_dictionary(self) -> bool:
|
||||||
|
return bool(self._replacements or self._terms)
|
||||||
|
|
||||||
|
def apply_deterministic_replacements(self, text: str) -> str:
|
||||||
|
if not text or self._replacement_pattern is None:
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _replace(match: re.Match[str]) -> str:
|
||||||
|
source_text = match.group(0)
|
||||||
|
key = _normalize_key(source_text)
|
||||||
|
return self._replacement_map.get(key, source_text)
|
||||||
|
|
||||||
|
return self._replacement_pattern.sub(_replace, text)
|
||||||
|
|
||||||
|
def build_stt_hints(self) -> tuple[str, str]:
|
||||||
|
return self._stt_hotwords, self._stt_initial_prompt
|
||||||
|
|
||||||
|
def build_ai_dictionary_context(self, max_lines: int = 80, char_budget: int = 1500) -> str:
|
||||||
|
lines: list[str] = []
|
||||||
|
for rule in self._replacements:
|
||||||
|
lines.append(f"replace: {rule.source} -> {rule.target}")
|
||||||
|
for term in self._terms:
|
||||||
|
lines.append(f"prefer: {term}")
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
out: list[str] = []
|
||||||
|
used = 0
|
||||||
|
for line in lines:
|
||||||
|
if len(out) >= max_lines:
|
||||||
|
break
|
||||||
|
addition = len(line) + (1 if out else 0)
|
||||||
|
if used + addition > char_budget:
|
||||||
|
break
|
||||||
|
out.append(line)
|
||||||
|
used += addition
|
||||||
|
return "\n".join(out)
|
||||||
|
|
||||||
|
def infer_domain(self, text: str) -> DomainResult:
|
||||||
|
if not self._domain_enabled:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
normalized = text.casefold()
|
||||||
|
tokens = re.findall(r"[a-z0-9+#./_-]+", normalized)
|
||||||
|
if not tokens:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
scores = {domain: 0 for domain in DOMAIN_ORDER}
|
||||||
|
for token in tokens:
|
||||||
|
for domain, keywords in DOMAIN_KEYWORDS.items():
|
||||||
|
if token in keywords:
|
||||||
|
scores[domain] += 2
|
||||||
|
|
||||||
|
for domain, phrases in DOMAIN_PHRASES.items():
|
||||||
|
for phrase in phrases:
|
||||||
|
if phrase in normalized:
|
||||||
|
scores[domain] += 2
|
||||||
|
|
||||||
|
if any(token in GREETING_TOKENS for token in tokens):
|
||||||
|
scores[DOMAIN_PERSONAL_NAMES] += 1
|
||||||
|
|
||||||
|
# Boost domains from configured dictionary terms and replacement targets.
|
||||||
|
dictionary_tokens = self._dictionary_tokens()
|
||||||
|
for token in dictionary_tokens:
|
||||||
|
for domain, keywords in DOMAIN_KEYWORDS.items():
|
||||||
|
if token in keywords and token in tokens:
|
||||||
|
scores[domain] += 1
|
||||||
|
|
||||||
|
top_domain = DOMAIN_GENERAL
|
||||||
|
top_score = 0
|
||||||
|
total_score = 0
|
||||||
|
for domain in DOMAIN_ORDER:
|
||||||
|
score = scores[domain]
|
||||||
|
total_score += score
|
||||||
|
if score > top_score:
|
||||||
|
top_score = score
|
||||||
|
top_domain = domain
|
||||||
|
|
||||||
|
if top_score < 2 or total_score == 0:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
confidence = top_score / total_score
|
||||||
|
if confidence < 0.45:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
return DomainResult(name=top_domain, confidence=round(confidence, 2))
|
||||||
|
|
||||||
|
def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
|
||||||
|
items = _dedupe_preserve_order(
|
||||||
|
[rule.target for rule in self._replacements] + self._terms
|
||||||
|
)
|
||||||
|
words: list[str] = []
|
||||||
|
used = 0
|
||||||
|
for item in items:
|
||||||
|
if len(words) >= limit:
|
||||||
|
break
|
||||||
|
addition = len(item) + (2 if words else 0)
|
||||||
|
if used + addition > char_budget:
|
||||||
|
break
|
||||||
|
words.append(item)
|
||||||
|
used += addition
|
||||||
|
return ", ".join(words)
|
||||||
|
|
||||||
|
def _build_stt_initial_prompt(self, *, char_budget: int) -> str:
|
||||||
|
if not self._stt_hotwords:
|
||||||
|
return ""
|
||||||
|
prefix = "Preferred vocabulary: "
|
||||||
|
available = max(char_budget - len(prefix), 0)
|
||||||
|
hotwords = self._stt_hotwords[:available].rstrip(", ")
|
||||||
|
if not hotwords:
|
||||||
|
return ""
|
||||||
|
return prefix + hotwords
|
||||||
|
|
||||||
|
def _dictionary_tokens(self) -> set[str]:
|
||||||
|
values: list[str] = []
|
||||||
|
for rule in self._replacements:
|
||||||
|
values.append(rule.source)
|
||||||
|
values.append(rule.target)
|
||||||
|
values.extend(self._terms)
|
||||||
|
|
||||||
|
tokens: set[str] = set()
|
||||||
|
for value in values:
|
||||||
|
for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()):
|
||||||
|
tokens.add(token)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
    """Compile one case-insensitive pattern matching any replacement source.

    Sources are trimmed, blanks dropped, and duplicates removed
    (case- and whitespace-insensitively, first spelling kept), then
    sorted longest-first so overlapping sources resolve to the longest
    match. Matches are boundary-aware: a source never matches inside a
    larger word. Returns None when no usable source remains.
    """
    # Inline of _dedupe_preserve_order: strip, drop blanks, dedupe by
    # casefolded whitespace-collapsed key while preserving order.
    seen: set[str] = set()
    candidates: list[str] = []
    for raw in sources:
        text = raw.strip()
        if not text:
            continue
        key = " ".join(text.casefold().split())
        if key in seen:
            continue
        seen.add(key)
        candidates.append(text)
    if not candidates:
        return None
    # Longest-first ordering makes the alternation prefer the longest source.
    candidates.sort(key=lambda text: (-len(text), text.casefold()))
    alternation = "|".join(re.escape(text) for text in candidates)
    return re.compile(r"(?<!\w)(" + alternation + r")(?!\w)", flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _dedupe_preserve_order(values: list[str]) -> list[str]:
    """Strip entries, drop blanks, and de-duplicate while keeping order.

    Duplicates are detected case-insensitively and with internal
    whitespace collapsed; the first spelling encountered is kept.
    """
    result: list[str] = []
    known: set[str] = set()
    for entry in values:
        stripped = entry.strip()
        if not stripped:
            continue
        # Inline of _normalize_key: casefold + collapse whitespace runs.
        fingerprint = " ".join(stripped.casefold().split())
        if fingerprint not in known:
            known.add(fingerprint)
            result.append(stripped)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_key(value: str) -> str:
|
||||||
|
return " ".join(value.casefold().split())
|
||||||
|
|
@ -27,6 +27,12 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||||
self.assertTrue(cfg.ai.enabled)
|
self.assertTrue(cfg.ai.enabled)
|
||||||
self.assertFalse(cfg.logging.log_transcript)
|
self.assertFalse(cfg.logging.log_transcript)
|
||||||
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
|
self.assertEqual(cfg.vocabulary.terms, [])
|
||||||
|
self.assertEqual(cfg.vocabulary.max_rules, 500)
|
||||||
|
self.assertEqual(cfg.vocabulary.max_terms, 500)
|
||||||
|
self.assertTrue(cfg.domain_inference.enabled)
|
||||||
|
self.assertEqual(cfg.domain_inference.mode, "auto")
|
||||||
|
|
||||||
def test_loads_nested_config(self):
|
def test_loads_nested_config(self):
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -36,6 +42,16 @@ class ConfigTests(unittest.TestCase):
|
||||||
"injection": {"backend": "injection"},
|
"injection": {"backend": "injection"},
|
||||||
"ai": {"enabled": False},
|
"ai": {"enabled": False},
|
||||||
"logging": {"log_transcript": True},
|
"logging": {"log_transcript": True},
|
||||||
|
"vocabulary": {
|
||||||
|
"replacements": [
|
||||||
|
{"from": "Martha", "to": "Marta"},
|
||||||
|
{"from": "docker", "to": "Docker"},
|
||||||
|
],
|
||||||
|
"terms": ["Systemd", "Kubernetes"],
|
||||||
|
"max_rules": 100,
|
||||||
|
"max_terms": 200,
|
||||||
|
},
|
||||||
|
"domain_inference": {"enabled": True, "mode": "auto"},
|
||||||
}
|
}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
|
|
@ -50,6 +66,14 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.injection.backend, "injection")
|
self.assertEqual(cfg.injection.backend, "injection")
|
||||||
self.assertFalse(cfg.ai.enabled)
|
self.assertFalse(cfg.ai.enabled)
|
||||||
self.assertTrue(cfg.logging.log_transcript)
|
self.assertTrue(cfg.logging.log_transcript)
|
||||||
|
self.assertEqual(cfg.vocabulary.max_rules, 100)
|
||||||
|
self.assertEqual(cfg.vocabulary.max_terms, 200)
|
||||||
|
self.assertEqual(len(cfg.vocabulary.replacements), 2)
|
||||||
|
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
||||||
|
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
||||||
|
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
|
||||||
|
self.assertTrue(cfg.domain_inference.enabled)
|
||||||
|
self.assertEqual(cfg.domain_inference.mode, "auto")
|
||||||
|
|
||||||
def test_loads_legacy_keys(self):
|
def test_loads_legacy_keys(self):
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -74,6 +98,7 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||||
self.assertFalse(cfg.ai.enabled)
|
self.assertFalse(cfg.ai.enabled)
|
||||||
self.assertTrue(cfg.logging.log_transcript)
|
self.assertTrue(cfg.logging.log_transcript)
|
||||||
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
|
|
||||||
def test_invalid_injection_backend_raises(self):
|
def test_invalid_injection_backend_raises(self):
|
||||||
payload = {"injection": {"backend": "invalid"}}
|
payload = {"injection": {"backend": "invalid"}}
|
||||||
|
|
@ -93,6 +118,65 @@ class ConfigTests(unittest.TestCase):
|
||||||
with self.assertRaisesRegex(ValueError, "logging.log_transcript"):
|
with self.assertRaisesRegex(ValueError, "logging.log_transcript"):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
|
def test_conflicting_replacements_raise(self):
    """Two rules with the same source but different targets are rejected."""
    config = {
        "vocabulary": {
            "replacements": [
                {"from": "Martha", "to": "Marta"},
                {"from": "martha", "to": "Martha"},
            ]
        }
    }
    with tempfile.TemporaryDirectory() as workdir:
        config_path = Path(workdir) / "config.json"
        config_path.write_text(json.dumps(config), encoding="utf-8")

        with self.assertRaisesRegex(ValueError, "conflicting"):
            load(str(config_path))
|
||||||
|
|
||||||
|
def test_duplicate_rules_and_terms_are_deduplicated(self):
    """Case-insensitive duplicates collapse to the first spelling."""
    config = {
        "vocabulary": {
            "replacements": [
                {"from": "docker", "to": "Docker"},
                {"from": "DOCKER", "to": "Docker"},
            ],
            "terms": ["Systemd", "systemd"],
        }
    }
    with tempfile.TemporaryDirectory() as workdir:
        config_path = Path(workdir) / "config.json"
        config_path.write_text(json.dumps(config), encoding="utf-8")

        cfg = load(str(config_path))

        self.assertEqual(len(cfg.vocabulary.replacements), 1)
        first_rule = cfg.vocabulary.replacements[0]
        self.assertEqual(first_rule.source, "docker")
        self.assertEqual(first_rule.target, "Docker")
        self.assertEqual(cfg.vocabulary.terms, ["Systemd"])
|
||||||
|
|
||||||
|
def test_wildcard_term_raises(self):
    """Terms containing wildcard characters are rejected at load time."""
    config = {
        "vocabulary": {
            "terms": ["Dock*"],
        }
    }
    with tempfile.TemporaryDirectory() as workdir:
        config_path = Path(workdir) / "config.json"
        config_path.write_text(json.dumps(config), encoding="utf-8")

        with self.assertRaisesRegex(ValueError, "wildcard"):
            load(str(config_path))
|
||||||
|
|
||||||
|
def test_invalid_domain_mode_raises(self):
    """Only known domain_inference modes are accepted."""
    config = {"domain_inference": {"mode": "heuristic"}}
    with tempfile.TemporaryDirectory() as workdir:
        config_path = Path(workdir) / "config.json"
        config_path.write_text(json.dumps(config), encoding="utf-8")

        with self.assertRaisesRegex(ValueError, "domain_inference.mode"):
            load(str(config_path))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ if str(SRC) not in sys.path:
|
||||||
sys.path.insert(0, str(SRC))
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
import leld
|
import leld
|
||||||
from config import Config
|
from config import Config, VocabularyReplacement
|
||||||
|
|
||||||
|
|
||||||
class FakeDesktop:
|
class FakeDesktop:
|
||||||
|
|
@ -32,8 +32,43 @@ class FakeSegment:
|
||||||
|
|
||||||
|
|
||||||
class FakeModel:
    """Whisper-model stub without hotword / initial-prompt support."""

    def __init__(self, text: str = "hello world"):
        # Transcript returned on every transcribe() call.
        self.text = text
        # Keyword arguments seen on the most recent transcribe() call.
        self.last_kwargs = {}

    def transcribe(self, _audio, language=None, vad_filter=None):
        """Record the call kwargs and return one canned segment."""
        self.last_kwargs = {"language": language, "vad_filter": vad_filter}
        segments = [FakeSegment(self.text)]
        return segments, self.last_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class FakeHintModel:
    """Whisper-model stub that also accepts hotwords and initial_prompt."""

    def __init__(self, text: str = "hello world"):
        # Transcript returned on every transcribe() call.
        self.text = text
        # Keyword arguments seen on the most recent transcribe() call.
        self.last_kwargs = {}

    def transcribe(
        self,
        _audio,
        language=None,
        vad_filter=None,
        hotwords=None,
        initial_prompt=None,
    ):
        """Record every hint so tests can assert what the daemon passed."""
        self.last_kwargs = dict(
            language=language,
            vad_filter=vad_filter,
            hotwords=hotwords,
            initial_prompt=initial_prompt,
        )
        return [FakeSegment(self.text)], self.last_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAIProcessor:
    """AI-cleanup stub: passes the transcript through unchanged."""

    def process(self, text, lang="en", **_kwargs):
        # Identity transform keeps daemon tests focused on the pipeline.
        return text
|
||||||
|
|
||||||
|
|
||||||
class FakeAudio:
|
class FakeAudio:
|
||||||
|
|
@ -48,12 +83,13 @@ class DaemonTests(unittest.TestCase):
|
||||||
cfg.logging.log_transcript = False
|
cfg.logging.log_transcript = False
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
@patch("leld._build_whisper_model", return_value=FakeModel())
|
|
||||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
||||||
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
||||||
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock, _model_mock):
|
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
|
||||||
desktop = FakeDesktop()
|
desktop = FakeDesktop()
|
||||||
|
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||||
|
daemon.ai_processor = FakeAIProcessor()
|
||||||
daemon._start_stop_worker = (
|
daemon._start_stop_worker = (
|
||||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||||
stream, record, trigger, process_audio
|
stream, record, trigger, process_audio
|
||||||
|
|
@ -68,12 +104,13 @@ class DaemonTests(unittest.TestCase):
|
||||||
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
||||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard")])
|
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard")])
|
||||||
|
|
||||||
@patch("leld._build_whisper_model", return_value=FakeModel())
|
|
||||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
||||||
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
||||||
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock, _model_mock):
|
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock):
|
||||||
desktop = FakeDesktop()
|
desktop = FakeDesktop()
|
||||||
|
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||||
|
daemon.ai_processor = FakeAIProcessor()
|
||||||
daemon._start_stop_worker = (
|
daemon._start_stop_worker = (
|
||||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||||
stream, record, trigger, process_audio
|
stream, record, trigger, process_audio
|
||||||
|
|
@ -87,6 +124,60 @@ class DaemonTests(unittest.TestCase):
|
||||||
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
||||||
self.assertEqual(desktop.inject_calls, [])
|
self.assertEqual(desktop.inject_calls, [])
|
||||||
|
|
||||||
|
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
@patch("leld.start_audio_recording", return_value=(object(), object()))
def test_dictionary_replacement_applies_after_ai(self, _start_mock, _stop_mock):
    """Deterministic vocabulary replacements run on the final injected text."""
    fake_desktop = FakeDesktop()
    stt_stub = FakeModel(text="good morning martha")
    config = self._config()
    config.vocabulary.replacements = [
        VocabularyReplacement(source="Martha", target="Marta")
    ]

    # NOTE(review): model patch assumed to be needed only while the
    # Daemon is constructed — confirm against leld.Daemon.__init__.
    with patch("leld._build_whisper_model", return_value=stt_stub):
        daemon = leld.Daemon(config, fake_desktop, verbose=False)
    daemon.ai_processor = FakeAIProcessor()
    # Run the worker inline so the test stays single-threaded.
    daemon._start_stop_worker = (
        lambda stream, record, trigger, process_audio: daemon._stop_and_process(
            stream, record, trigger, process_audio
        )
    )

    daemon.toggle()
    daemon.toggle()

    self.assertEqual(fake_desktop.inject_calls, [("good morning Marta", "clipboard")])
|
||||||
|
|
||||||
|
def test_transcribe_skips_hints_when_model_does_not_support_them(self):
    """Models lacking hint kwargs are called without hotwords or a prompt."""
    fake_desktop = FakeDesktop()
    stt_stub = FakeModel(text="hello")
    config = self._config()
    config.vocabulary.terms = ["Docker", "Systemd"]

    with patch("leld._build_whisper_model", return_value=stt_stub):
        daemon = leld.Daemon(config, fake_desktop, verbose=False)

    transcript = daemon._transcribe(object())

    self.assertEqual(transcript, "hello")
    self.assertNotIn("hotwords", stt_stub.last_kwargs)
    self.assertNotIn("initial_prompt", stt_stub.last_kwargs)
|
||||||
|
|
||||||
|
def test_transcribe_applies_hints_when_model_supports_them(self):
    """Hint-capable models receive hotwords and a vocabulary prompt."""
    fake_desktop = FakeDesktop()
    stt_stub = FakeHintModel(text="hello")
    config = self._config()
    config.vocabulary.terms = ["Systemd"]
    config.vocabulary.replacements = [
        VocabularyReplacement(source="docker", target="Docker")
    ]

    with patch("leld._build_whisper_model", return_value=stt_stub):
        daemon = leld.Daemon(config, fake_desktop, verbose=False)

    transcript = daemon._transcribe(object())

    self.assertEqual(transcript, "hello")
    self.assertIn("Docker", stt_stub.last_kwargs["hotwords"])
    self.assertIn("Systemd", stt_stub.last_kwargs["hotwords"])
    self.assertIn("Preferred vocabulary", stt_stub.last_kwargs["initial_prompt"])
|
||||||
|
|
||||||
|
|
||||||
class LockTests(unittest.TestCase):
|
class LockTests(unittest.TestCase):
|
||||||
def test_lock_rejects_second_instance(self):
|
def test_lock_rejects_second_instance(self):
|
||||||
|
|
|
||||||
76
tests/test_vocabulary.py
Normal file
76
tests/test_vocabulary.py
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
if str(SRC) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
|
from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
|
||||||
|
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
|
||||||
|
|
||||||
|
|
||||||
|
class VocabularyEngineTests(unittest.TestCase):
    """Behavioral tests for VocabularyEngine replacements and domain inference."""

    def _engine(self, replacements=None, terms=None, domain_enabled=True):
        """Construct an engine from plain lists, defaulting to empty config."""
        vocab_cfg = VocabularyConfig(
            replacements=replacements or [],
            terms=terms or [],
        )
        domain_cfg = DomainInferenceConfig(enabled=domain_enabled, mode="auto")
        return VocabularyEngine(vocab_cfg, domain_cfg)

    def test_boundary_aware_replacement(self):
        """Replacements match whole words only, never inside larger words."""
        engine = self._engine(
            replacements=[VocabularyReplacement(source="Martha", target="Marta")],
        )
        rewritten = engine.apply_deterministic_replacements(
            "Martha met Marthaville and Martha."
        )
        self.assertEqual(rewritten, "Marta met Marthaville and Marta.")

    def test_longest_match_replacement_wins(self):
        """When sources overlap, the longer source takes precedence."""
        engine = self._engine(
            replacements=[
                VocabularyReplacement(source="new york", target="NYC"),
                VocabularyReplacement(source="york", target="Yorkshire"),
            ],
        )
        rewritten = engine.apply_deterministic_replacements("new york york")
        self.assertEqual(rewritten, "NYC Yorkshire")

    def test_stt_hints_are_bounded(self):
        """Hotword and prompt strings stay within their size budgets."""
        many_terms = [f"term{index}" for index in range(300)]
        engine = self._engine(terms=many_terms)

        hotwords, prompt = engine.build_stt_hints()

        self.assertLessEqual(len(hotwords), 1024)
        self.assertLessEqual(len(prompt), 600)

    def test_domain_inference_general_fallback(self):
        """Text with no strong signal falls back to the general domain."""
        result = self._engine().infer_domain("please call me later")
        self.assertEqual(result.name, DOMAIN_GENERAL)
        self.assertEqual(result.confidence, 0.0)

    def test_domain_inference_for_technical_text(self):
        """Clearly technical text is classified away from the general domain."""
        engine = self._engine(terms=["Docker", "Systemd"])
        result = engine.infer_domain("restart Docker and systemd service on prod")
        self.assertNotEqual(result.name, DOMAIN_GENERAL)
        self.assertGreater(result.confidence, 0.0)

    def test_domain_inference_can_be_disabled(self):
        """With inference disabled, every text maps to the general domain."""
        engine = self._engine(domain_enabled=False)
        result = engine.infer_domain("please restart docker")
        self.assertEqual(result.name, DOMAIN_GENERAL)
        self.assertEqual(result.confidence, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue