Add vocabulary correction pipeline and example config
This commit is contained in:
parent
f9224621fa
commit
c3503fbbde
9 changed files with 865 additions and 23 deletions
76
tests/test_vocabulary.py
Normal file
76
tests/test_vocabulary.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
|
||||
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
|
||||
|
||||
|
||||
class VocabularyEngineTests(unittest.TestCase):
|
||||
def _engine(self, replacements=None, terms=None, domain_enabled=True):
|
||||
vocab = VocabularyConfig(
|
||||
replacements=replacements or [],
|
||||
terms=terms or [],
|
||||
)
|
||||
domain = DomainInferenceConfig(enabled=domain_enabled, mode="auto")
|
||||
return VocabularyEngine(vocab, domain)
|
||||
|
||||
def test_boundary_aware_replacement(self):
|
||||
engine = self._engine(
|
||||
replacements=[VocabularyReplacement(source="Martha", target="Marta")],
|
||||
)
|
||||
|
||||
text = "Martha met Marthaville and Martha."
|
||||
out = engine.apply_deterministic_replacements(text)
|
||||
|
||||
self.assertEqual(out, "Marta met Marthaville and Marta.")
|
||||
|
||||
def test_longest_match_replacement_wins(self):
|
||||
engine = self._engine(
|
||||
replacements=[
|
||||
VocabularyReplacement(source="new york", target="NYC"),
|
||||
VocabularyReplacement(source="york", target="Yorkshire"),
|
||||
],
|
||||
)
|
||||
|
||||
out = engine.apply_deterministic_replacements("new york york")
|
||||
self.assertEqual(out, "NYC Yorkshire")
|
||||
|
||||
def test_stt_hints_are_bounded(self):
|
||||
terms = [f"term{i}" for i in range(300)]
|
||||
engine = self._engine(terms=terms)
|
||||
|
||||
hotwords, prompt = engine.build_stt_hints()
|
||||
|
||||
self.assertLessEqual(len(hotwords), 1024)
|
||||
self.assertLessEqual(len(prompt), 600)
|
||||
|
||||
def test_domain_inference_general_fallback(self):
|
||||
engine = self._engine()
|
||||
result = engine.infer_domain("please call me later")
|
||||
|
||||
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
def test_domain_inference_for_technical_text(self):
|
||||
engine = self._engine(terms=["Docker", "Systemd"])
|
||||
result = engine.infer_domain("restart Docker and systemd service on prod")
|
||||
|
||||
self.assertNotEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertGreater(result.confidence, 0.0)
|
||||
|
||||
def test_domain_inference_can_be_disabled(self):
|
||||
engine = self._engine(domain_enabled=False)
|
||||
result = engine.infer_domain("please restart docker")
|
||||
|
||||
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue