Add vocabulary correction pipeline and example config
This commit is contained in:
parent
f9224621fa
commit
c3503fbbde
9 changed files with 865 additions and 23 deletions
|
|
@ -27,6 +27,12 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||
self.assertTrue(cfg.ai.enabled)
|
||||
self.assertFalse(cfg.logging.log_transcript)
|
||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||
self.assertEqual(cfg.vocabulary.terms, [])
|
||||
self.assertEqual(cfg.vocabulary.max_rules, 500)
|
||||
self.assertEqual(cfg.vocabulary.max_terms, 500)
|
||||
self.assertTrue(cfg.domain_inference.enabled)
|
||||
self.assertEqual(cfg.domain_inference.mode, "auto")
|
||||
|
||||
def test_loads_nested_config(self):
|
||||
payload = {
|
||||
|
|
@ -36,6 +42,16 @@ class ConfigTests(unittest.TestCase):
|
|||
"injection": {"backend": "injection"},
|
||||
"ai": {"enabled": False},
|
||||
"logging": {"log_transcript": True},
|
||||
"vocabulary": {
|
||||
"replacements": [
|
||||
{"from": "Martha", "to": "Marta"},
|
||||
{"from": "docker", "to": "Docker"},
|
||||
],
|
||||
"terms": ["Systemd", "Kubernetes"],
|
||||
"max_rules": 100,
|
||||
"max_terms": 200,
|
||||
},
|
||||
"domain_inference": {"enabled": True, "mode": "auto"},
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
|
|
@ -50,6 +66,14 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.injection.backend, "injection")
|
||||
self.assertFalse(cfg.ai.enabled)
|
||||
self.assertTrue(cfg.logging.log_transcript)
|
||||
self.assertEqual(cfg.vocabulary.max_rules, 100)
|
||||
self.assertEqual(cfg.vocabulary.max_terms, 200)
|
||||
self.assertEqual(len(cfg.vocabulary.replacements), 2)
|
||||
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
||||
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
||||
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
|
||||
self.assertTrue(cfg.domain_inference.enabled)
|
||||
self.assertEqual(cfg.domain_inference.mode, "auto")
|
||||
|
||||
def test_loads_legacy_keys(self):
|
||||
payload = {
|
||||
|
|
@ -74,6 +98,7 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||
self.assertFalse(cfg.ai.enabled)
|
||||
self.assertTrue(cfg.logging.log_transcript)
|
||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||
|
||||
def test_invalid_injection_backend_raises(self):
|
||||
payload = {"injection": {"backend": "invalid"}}
|
||||
|
|
@ -93,6 +118,65 @@ class ConfigTests(unittest.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "logging.log_transcript"):
|
||||
load(str(path))
|
||||
|
||||
def test_conflicting_replacements_raise(self):
    """Two replacement rules whose sources collide (differing only by case)
    but map to different targets must be rejected by the loader."""
    payload = {
        "vocabulary": {
            "replacements": [
                {"from": "Martha", "to": "Marta"},
                {"from": "martha", "to": "Martha"},
            ]
        }
    }
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "config.json"
        path.write_text(json.dumps(payload), encoding="utf-8")

        # The error message is expected to mention "conflicting".
        with self.assertRaisesRegex(ValueError, "conflicting"):
            load(str(path))
|
||||
|
||||
def test_duplicate_rules_and_terms_are_deduplicated(self):
    """Duplicate rules/terms (case-insensitive, same target) collapse to one
    entry; the first occurrence's spelling is kept."""
    payload = {
        "vocabulary": {
            "replacements": [
                {"from": "docker", "to": "Docker"},
                {"from": "DOCKER", "to": "Docker"},
            ],
            "terms": ["Systemd", "systemd"],
        }
    }
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "config.json"
        path.write_text(json.dumps(payload), encoding="utf-8")

        cfg = load(str(path))

    # Only the first spelling of each duplicate survives.
    self.assertEqual(len(cfg.vocabulary.replacements), 1)
    self.assertEqual(cfg.vocabulary.replacements[0].source, "docker")
    self.assertEqual(cfg.vocabulary.replacements[0].target, "Docker")
    self.assertEqual(cfg.vocabulary.terms, ["Systemd"])
|
||||
|
||||
def test_wildcard_term_raises(self):
    """A vocabulary term containing a wildcard character is invalid."""
    payload = {
        "vocabulary": {
            "terms": ["Dock*"],
        }
    }
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "config.json"
        path.write_text(json.dumps(payload), encoding="utf-8")

        # The error message is expected to mention "wildcard".
        with self.assertRaisesRegex(ValueError, "wildcard"):
            load(str(path))
|
||||
|
||||
def test_invalid_domain_mode_raises(self):
    """An unsupported domain_inference.mode value must be rejected with an
    error that names the offending config key."""
    payload = {"domain_inference": {"mode": "heuristic"}}
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "config.json"
        path.write_text(json.dumps(payload), encoding="utf-8")

        with self.assertRaisesRegex(ValueError, "domain_inference.mode"):
            load(str(path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ if str(SRC) not in sys.path:
|
|||
sys.path.insert(0, str(SRC))
|
||||
|
||||
import leld
|
||||
from config import Config
|
||||
from config import Config, VocabularyReplacement
|
||||
|
||||
|
||||
class FakeDesktop:
|
||||
|
|
@ -32,8 +32,43 @@ class FakeSegment:
|
|||
|
||||
|
||||
class FakeModel:
|
||||
def __init__(self, text: str = "hello world"):
|
||||
self.text = text
|
||||
self.last_kwargs = {}
|
||||
|
||||
def transcribe(self, _audio, language=None, vad_filter=None):
|
||||
return [FakeSegment("hello world")], {"language": language, "vad_filter": vad_filter}
|
||||
self.last_kwargs = {
|
||||
"language": language,
|
||||
"vad_filter": vad_filter,
|
||||
}
|
||||
return [FakeSegment(self.text)], self.last_kwargs
|
||||
|
||||
|
||||
class FakeHintModel:
    """Stand-in Whisper model whose transcribe() accepts vocabulary hint
    kwargs (hotwords / initial_prompt), recording every call's kwargs."""

    def __init__(self, text: str = "hello world"):
        # Transcript text returned by every transcribe() call.
        self.text = text
        # Kwargs of the most recent transcribe() call; empty until first call.
        self.last_kwargs = {}

    def transcribe(
        self,
        _audio,
        language=None,
        vad_filter=None,
        hotwords=None,
        initial_prompt=None,
    ):
        """Record the received kwargs and return a single fake segment."""
        self.last_kwargs = dict(
            language=language,
            vad_filter=vad_filter,
            hotwords=hotwords,
            initial_prompt=initial_prompt,
        )
        return [FakeSegment(self.text)], self.last_kwargs
|
||||
|
||||
|
||||
class FakeAIProcessor:
    """No-op AI post-processor: echoes the transcript back unchanged."""

    def process(self, text, lang="en", **_kwargs):
        # Identity transform; any extra keyword arguments are accepted
        # and ignored so the daemon's call signature is satisfied.
        return text
|
||||
|
||||
|
||||
class FakeAudio:
|
||||
|
|
@ -48,12 +83,13 @@ class DaemonTests(unittest.TestCase):
|
|||
cfg.logging.log_transcript = False
|
||||
return cfg
|
||||
|
||||
@patch("leld._build_whisper_model", return_value=FakeModel())
|
||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
||||
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
||||
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock, _model_mock):
|
||||
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
|
||||
desktop = FakeDesktop()
|
||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||
daemon.ai_processor = FakeAIProcessor()
|
||||
daemon._start_stop_worker = (
|
||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||
stream, record, trigger, process_audio
|
||||
|
|
@ -68,12 +104,13 @@ class DaemonTests(unittest.TestCase):
|
|||
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard")])
|
||||
|
||||
@patch("leld._build_whisper_model", return_value=FakeModel())
|
||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
||||
@patch("leld.start_audio_recording", return_value=(object(), object()))
|
||||
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock, _model_mock):
|
||||
def test_shutdown_stops_recording_without_injection(self, _start_mock, _stop_mock):
|
||||
desktop = FakeDesktop()
|
||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||
daemon.ai_processor = FakeAIProcessor()
|
||||
daemon._start_stop_worker = (
|
||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||
stream, record, trigger, process_audio
|
||||
|
|
@ -87,6 +124,60 @@ class DaemonTests(unittest.TestCase):
|
|||
self.assertEqual(daemon.get_state(), leld.State.IDLE)
|
||||
self.assertEqual(desktop.inject_calls, [])
|
||||
|
||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
@patch("leld.start_audio_recording", return_value=(object(), object()))
def test_dictionary_replacement_applies_after_ai(self, _start_mock, _stop_mock):
    """Deterministic vocabulary replacements are applied to the text that
    finally gets injected (the asserted output shows the replacement done
    on the AI-processed transcript)."""
    desktop = FakeDesktop()
    model = FakeModel(text="good morning martha")
    cfg = self._config()
    # Rule source is "Martha" but the transcript says "martha"; the asserted
    # output implies matching is case-insensitive on the source.
    cfg.vocabulary.replacements = [VocabularyReplacement(source="Martha", target="Marta")]

    with patch("leld._build_whisper_model", return_value=model):
        daemon = leld.Daemon(cfg, desktop, verbose=False)
        daemon.ai_processor = FakeAIProcessor()
        # Run the stop/process pipeline synchronously instead of spawning
        # a worker, so assertions can run right after toggle().
        daemon._start_stop_worker = (
            lambda stream, record, trigger, process_audio: daemon._stop_and_process(
                stream, record, trigger, process_audio
            )
        )

    daemon.toggle()
    daemon.toggle()

    self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard")])
|
||||
|
||||
def test_transcribe_skips_hints_when_model_does_not_support_them(self):
    """When the model's transcribe() lacks hotwords/initial_prompt
    parameters, the daemon must not pass them."""
    desktop = FakeDesktop()
    # FakeModel.transcribe accepts only language/vad_filter.
    model = FakeModel(text="hello")
    cfg = self._config()
    cfg.vocabulary.terms = ["Docker", "Systemd"]

    with patch("leld._build_whisper_model", return_value=model):
        daemon = leld.Daemon(cfg, desktop, verbose=False)

    result = daemon._transcribe(object())

    self.assertEqual(result, "hello")
    # Hint kwargs must be absent from the call the model recorded.
    self.assertNotIn("hotwords", model.last_kwargs)
    self.assertNotIn("initial_prompt", model.last_kwargs)
|
||||
|
||||
def test_transcribe_applies_hints_when_model_supports_them(self):
    """When the model's transcribe() supports hint kwargs, the daemon feeds
    it hotwords built from terms + replacement targets, plus a vocabulary
    initial_prompt."""
    desktop = FakeDesktop()
    # FakeHintModel.transcribe accepts hotwords/initial_prompt and records them.
    model = FakeHintModel(text="hello")
    cfg = self._config()
    cfg.vocabulary.terms = ["Systemd"]
    cfg.vocabulary.replacements = [VocabularyReplacement(source="docker", target="Docker")]

    with patch("leld._build_whisper_model", return_value=model):
        daemon = leld.Daemon(cfg, desktop, verbose=False)

    result = daemon._transcribe(object())

    self.assertEqual(result, "hello")
    # Hotwords include both the configured term and the replacement target.
    self.assertIn("Docker", model.last_kwargs["hotwords"])
    self.assertIn("Systemd", model.last_kwargs["hotwords"])
    self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
|
||||
|
||||
|
||||
class LockTests(unittest.TestCase):
|
||||
def test_lock_rejects_second_instance(self):
|
||||
|
|
|
|||
76
tests/test_vocabulary.py
Normal file
76
tests/test_vocabulary.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import sys
import unittest
from pathlib import Path

# Make the project's src/ directory importable when the tests are run
# directly from the repository checkout.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
|
||||
|
||||
|
||||
class VocabularyEngineTests(unittest.TestCase):
    """Tests for VocabularyEngine: deterministic replacements, bounded STT
    hints, and domain inference."""

    def _engine(self, replacements=None, terms=None, domain_enabled=True):
        """Build a VocabularyEngine from inline config; empty vocabulary by default."""
        vocab_cfg = VocabularyConfig(
            replacements=replacements or [],
            terms=terms or [],
        )
        domain_cfg = DomainInferenceConfig(enabled=domain_enabled, mode="auto")
        return VocabularyEngine(vocab_cfg, domain_cfg)

    def test_boundary_aware_replacement(self):
        # "Marthaville" must survive: only whole-word occurrences change.
        engine = self._engine(
            replacements=[VocabularyReplacement(source="Martha", target="Marta")],
        )
        sample = "Martha met Marthaville and Martha."
        self.assertEqual(
            engine.apply_deterministic_replacements(sample),
            "Marta met Marthaville and Marta.",
        )

    def test_longest_match_replacement_wins(self):
        # "new york" must be consumed by the longer rule before "york" applies.
        engine = self._engine(
            replacements=[
                VocabularyReplacement(source="new york", target="NYC"),
                VocabularyReplacement(source="york", target="Yorkshire"),
            ],
        )
        self.assertEqual(
            engine.apply_deterministic_replacements("new york york"),
            "NYC Yorkshire",
        )

    def test_stt_hints_are_bounded(self):
        # Even with many terms, hint strings stay within size caps.
        engine = self._engine(terms=[f"term{i}" for i in range(300)])
        hotwords, prompt = engine.build_stt_hints()
        self.assertLessEqual(len(hotwords), 1024)
        self.assertLessEqual(len(prompt), 600)

    def test_domain_inference_general_fallback(self):
        # No vocabulary signal: falls back to the general domain, zero confidence.
        result = self._engine().infer_domain("please call me later")
        self.assertEqual(result.name, DOMAIN_GENERAL)
        self.assertEqual(result.confidence, 0.0)

    def test_domain_inference_for_technical_text(self):
        # Configured terms appearing in the text push inference off "general".
        engine = self._engine(terms=["Docker", "Systemd"])
        result = engine.infer_domain("restart Docker and systemd service on prod")
        self.assertNotEqual(result.name, DOMAIN_GENERAL)
        self.assertGreater(result.confidence, 0.0)

    def test_domain_inference_can_be_disabled(self):
        # Disabled inference always reports general/0.0 regardless of content.
        result = self._engine(domain_enabled=False).infer_domain("please restart docker")
        self.assertEqual(result.name, DOMAIN_GENERAL)
        self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
|
||||
# Allow running this test module directly (python tests/test_vocabulary.py).
if __name__ == "__main__":
    unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue