Add vocabulary correction pipeline and example config

This commit is contained in:
Thales Maciel 2026-02-25 10:03:32 -03:00
parent f9224621fa
commit c3503fbbde
9 changed files with 865 additions and 23 deletions

View file

@ -27,6 +27,12 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.injection.backend, "clipboard")
self.assertTrue(cfg.ai.enabled)
self.assertFalse(cfg.logging.log_transcript)
self.assertEqual(cfg.vocabulary.replacements, [])
self.assertEqual(cfg.vocabulary.terms, [])
self.assertEqual(cfg.vocabulary.max_rules, 500)
self.assertEqual(cfg.vocabulary.max_terms, 500)
self.assertTrue(cfg.domain_inference.enabled)
self.assertEqual(cfg.domain_inference.mode, "auto")
def test_loads_nested_config(self):
payload = {
@ -36,6 +42,16 @@ class ConfigTests(unittest.TestCase):
"injection": {"backend": "injection"},
"ai": {"enabled": False},
"logging": {"log_transcript": True},
"vocabulary": {
"replacements": [
{"from": "Martha", "to": "Marta"},
{"from": "docker", "to": "Docker"},
],
"terms": ["Systemd", "Kubernetes"],
"max_rules": 100,
"max_terms": 200,
},
"domain_inference": {"enabled": True, "mode": "auto"},
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
@ -50,6 +66,14 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.injection.backend, "injection")
self.assertFalse(cfg.ai.enabled)
self.assertTrue(cfg.logging.log_transcript)
self.assertEqual(cfg.vocabulary.max_rules, 100)
self.assertEqual(cfg.vocabulary.max_terms, 200)
self.assertEqual(len(cfg.vocabulary.replacements), 2)
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
self.assertTrue(cfg.domain_inference.enabled)
self.assertEqual(cfg.domain_inference.mode, "auto")
def test_loads_legacy_keys(self):
payload = {
@ -74,6 +98,7 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.injection.backend, "clipboard")
self.assertFalse(cfg.ai.enabled)
self.assertTrue(cfg.logging.log_transcript)
self.assertEqual(cfg.vocabulary.replacements, [])
def test_invalid_injection_backend_raises(self):
payload = {"injection": {"backend": "invalid"}}
@ -93,6 +118,65 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "logging.log_transcript"):
load(str(path))
def test_conflicting_replacements_raise(self):
payload = {
"vocabulary": {
"replacements": [
{"from": "Martha", "to": "Marta"},
{"from": "martha", "to": "Martha"},
]
}
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "conflicting"):
load(str(path))
def test_duplicate_rules_and_terms_are_deduplicated(self):
payload = {
"vocabulary": {
"replacements": [
{"from": "docker", "to": "Docker"},
{"from": "DOCKER", "to": "Docker"},
],
"terms": ["Systemd", "systemd"],
}
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path))
self.assertEqual(len(cfg.vocabulary.replacements), 1)
self.assertEqual(cfg.vocabulary.replacements[0].source, "docker")
self.assertEqual(cfg.vocabulary.replacements[0].target, "Docker")
self.assertEqual(cfg.vocabulary.terms, ["Systemd"])
def test_wildcard_term_raises(self):
payload = {
"vocabulary": {
"terms": ["Dock*"],
}
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "wildcard"):
load(str(path))
def test_invalid_domain_mode_raises(self):
payload = {"domain_inference": {"mode": "heuristic"}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "domain_inference.mode"):
load(str(path))
if __name__ == "__main__":
unittest.main()

View file

@ -11,7 +11,7 @@ if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
import leld
from config import Config
from config import Config, VocabularyReplacement
class FakeDesktop:
@ -32,8 +32,43 @@ class FakeSegment:
class FakeModel:
def __init__(self, text: str = "hello world"):
self.text = text
self.last_kwargs = {}
def transcribe(self, _audio, language=None, vad_filter=None):
return [FakeSegment("hello world")], {"language": language, "vad_filter": vad_filter}
self.last_kwargs = {
"language": language,
"vad_filter": vad_filter,
}
return [FakeSegment(self.text)], self.last_kwargs
class FakeHintModel:
def __init__(self, text: str = "hello world"):
self.text = text
self.last_kwargs = {}
def transcribe(
self,
_audio,
language=None,
vad_filter=None,
hotwords=None,
initial_prompt=None,
):
self.last_kwargs = {
"language": language,
"vad_filter": vad_filter,
"hotwords": hotwords,
"initial_prompt": initial_prompt,
}
return [FakeSegment(self.text)], self.last_kwargs
class FakeAIProcessor:
def process(self, text, lang="en", **_kwargs):
return text
class FakeAudio:
@ -48,12 +83,13 @@ class DaemonTests(unittest.TestCase):
cfg.logging.log_transcript = False
return cfg
@patch("leld._build_whisper_model", return_value=FakeModel())
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
@patch("leld.start_audio_recording", return_value=(object(), object()))
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock, _model_mock):
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
daemon = leld.Daemon(self._config(), desktop, verbose=False)
with patch("leld._build_whisper_model", return_value=FakeModel()):
daemon = leld.Daemon(self._config(), desktop, verbose=False)
daemon.ai_processor = FakeAIProcessor()
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
@ -68,12 +104,13 @@ class DaemonTests(unittest.TestCase):
self.assertEqual(daemon.get_state(), leld.State.IDLE)
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard")])
@patch("leld._build_whisper_model", return_value=FakeModel())
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
@patch("leld.start_audio_recording", return_value=(object(), object()))
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock, _model_mock):
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
daemon = leld.Daemon(self._config(), desktop, verbose=False)
with patch("leld._build_whisper_model", return_value=FakeModel()):
daemon = leld.Daemon(self._config(), desktop, verbose=False)
daemon.ai_processor = FakeAIProcessor()
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
@ -87,6 +124,60 @@ class DaemonTests(unittest.TestCase):
self.assertEqual(daemon.get_state(), leld.State.IDLE)
self.assertEqual(desktop.inject_calls, [])
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
@patch("leld.start_audio_recording", return_value=(object(), object()))
def test_dictionary_replacement_applies_after_ai(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
model = FakeModel(text="good morning martha")
cfg = self._config()
cfg.vocabulary.replacements = [VocabularyReplacement(source="Martha", target="Marta")]
with patch("leld._build_whisper_model", return_value=model):
daemon = leld.Daemon(cfg, desktop, verbose=False)
daemon.ai_processor = FakeAIProcessor()
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
)
)
daemon.toggle()
daemon.toggle()
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard")])
def test_transcribe_skips_hints_when_model_does_not_support_them(self):
desktop = FakeDesktop()
model = FakeModel(text="hello")
cfg = self._config()
cfg.vocabulary.terms = ["Docker", "Systemd"]
with patch("leld._build_whisper_model", return_value=model):
daemon = leld.Daemon(cfg, desktop, verbose=False)
result = daemon._transcribe(object())
self.assertEqual(result, "hello")
self.assertNotIn("hotwords", model.last_kwargs)
self.assertNotIn("initial_prompt", model.last_kwargs)
def test_transcribe_applies_hints_when_model_supports_them(self):
desktop = FakeDesktop()
model = FakeHintModel(text="hello")
cfg = self._config()
cfg.vocabulary.terms = ["Systemd"]
cfg.vocabulary.replacements = [VocabularyReplacement(source="docker", target="Docker")]
with patch("leld._build_whisper_model", return_value=model):
daemon = leld.Daemon(cfg, desktop, verbose=False)
result = daemon._transcribe(object())
self.assertEqual(result, "hello")
self.assertIn("Docker", model.last_kwargs["hotwords"])
self.assertIn("Systemd", model.last_kwargs["hotwords"])
self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
class LockTests(unittest.TestCase):
def test_lock_rejects_second_instance(self):

76
tests/test_vocabulary.py Normal file
View file

@ -0,0 +1,76 @@
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
class VocabularyEngineTests(unittest.TestCase):
def _engine(self, replacements=None, terms=None, domain_enabled=True):
vocab = VocabularyConfig(
replacements=replacements or [],
terms=terms or [],
)
domain = DomainInferenceConfig(enabled=domain_enabled, mode="auto")
return VocabularyEngine(vocab, domain)
def test_boundary_aware_replacement(self):
engine = self._engine(
replacements=[VocabularyReplacement(source="Martha", target="Marta")],
)
text = "Martha met Marthaville and Martha."
out = engine.apply_deterministic_replacements(text)
self.assertEqual(out, "Marta met Marthaville and Marta.")
def test_longest_match_replacement_wins(self):
engine = self._engine(
replacements=[
VocabularyReplacement(source="new york", target="NYC"),
VocabularyReplacement(source="york", target="Yorkshire"),
],
)
out = engine.apply_deterministic_replacements("new york york")
self.assertEqual(out, "NYC Yorkshire")
def test_stt_hints_are_bounded(self):
terms = [f"term{i}" for i in range(300)]
engine = self._engine(terms=terms)
hotwords, prompt = engine.build_stt_hints()
self.assertLessEqual(len(hotwords), 1024)
self.assertLessEqual(len(prompt), 600)
def test_domain_inference_general_fallback(self):
engine = self._engine()
result = engine.infer_domain("please call me later")
self.assertEqual(result.name, DOMAIN_GENERAL)
self.assertEqual(result.confidence, 0.0)
def test_domain_inference_for_technical_text(self):
engine = self._engine(terms=["Docker", "Systemd"])
result = engine.infer_domain("restart Docker and systemd service on prod")
self.assertNotEqual(result.name, DOMAIN_GENERAL)
self.assertGreater(result.confidence, 0.0)
def test_domain_inference_can_be_disabled(self):
engine = self._engine(domain_enabled=False)
result = engine.infer_domain("please restart docker")
self.assertEqual(result.name, DOMAIN_GENERAL)
self.assertEqual(result.confidence, 0.0)
if __name__ == "__main__":
unittest.main()