Add vocabulary correction pipeline and example config
This commit is contained in:
parent
f9224621fa
commit
c3503fbbde
9 changed files with 865 additions and 23 deletions
280
src/vocabulary.py
Normal file
280
src/vocabulary.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
from config import DomainInferenceConfig, VocabularyConfig
|
||||
|
||||
|
||||
# Canonical identifiers for the inferred topic domains.
DOMAIN_GENERAL = "general"
DOMAIN_PERSONAL_NAMES = "personal_names"
DOMAIN_SOFTWARE_DEV = "software_dev"
DOMAIN_OPS_INFRA = "ops_infra"
DOMAIN_BUSINESS = "business"
DOMAIN_MEDICAL_LEGAL = "medical_legal"

# Scoring/tie-break order for domain inference: the first domain with the
# strictly highest score wins, so earlier entries win ties. The general
# domain is intentionally absent — it is the fallback result.
DOMAIN_ORDER = (
    DOMAIN_PERSONAL_NAMES,
    DOMAIN_SOFTWARE_DEV,
    DOMAIN_OPS_INFRA,
    DOMAIN_BUSINESS,
    DOMAIN_MEDICAL_LEGAL,
)
|
||||
|
||||
# Single-word cues per domain; each token hit scores +2 for its domain in
# infer_domain. The personal-names domain has no keyword set — it is scored
# from greeting cues instead.
DOMAIN_KEYWORDS = {
    DOMAIN_SOFTWARE_DEV: {
        "api",
        "bug",
        "code",
        "commit",
        "docker",
        "function",
        "git",
        "github",
        "javascript",
        "python",
        "refactor",
        "repository",
        "typescript",
        "unit",
        "test",
    },
    DOMAIN_OPS_INFRA: {
        "cluster",
        "container",
        "deploy",
        "deployment",
        "incident",
        "kubernetes",
        "monitoring",
        "nginx",
        "pod",
        "prod",
        "service",
        "systemd",
        "terraform",
    },
    DOMAIN_BUSINESS: {
        "budget",
        "client",
        "deadline",
        "finance",
        "invoice",
        "meeting",
        "milestone",
        "project",
        "quarter",
        "roadmap",
        "sales",
        "stakeholder",
    },
    DOMAIN_MEDICAL_LEGAL: {
        "agreement",
        "case",
        "claim",
        "compliance",
        "contract",
        "diagnosis",
        "liability",
        "patient",
        "prescription",
        "regulation",
        "symptom",
        "treatment",
    },
}
|
||||
|
||||
# Multi-word cues matched as substrings of the casefolded text; each phrase
# hit scores +2 for its domain in infer_domain.
DOMAIN_PHRASES = {
    DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"),
    DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"),
    DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"),
    DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"),
}

# Greeting cues that nudge scoring toward the personal-names domain (+1).
GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainResult:
    """Result of domain inference: a domain identifier plus a confidence.

    ``name`` is one of the DOMAIN_* constants; ``confidence`` is 0.0 whenever
    inference falls back to the general domain.
    """

    # Winning domain identifier (one of the DOMAIN_* constants).
    name: str
    # Winning domain's share of the total score, rounded to two decimals.
    confidence: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ReplacementView:
    """Immutable internal snapshot of one configured replacement rule."""

    # Phrase to look for in transcribed text.
    source: str
    # Text substituted when the source phrase matches.
    target: str
|
||||
|
||||
|
||||
class VocabularyEngine:
    """Applies configured vocabulary corrections and infers a topic domain.

    Built once from configuration; precomputes the replacement regex, its
    lookup map, and bounded speech-to-text hint strings so per-utterance
    calls stay cheap.
    """

    def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig):
        self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
        self._terms = list(vocab_cfg.terms)
        self._domain_enabled = bool(domain_cfg.enabled)

        # Case/whitespace-insensitive lookup from source phrase to target,
        # consulted by the regex substitution callback.
        self._replacement_map = {
            _normalize_key(rule.source): rule.target for rule in self._replacements
        }
        self._replacement_pattern = _build_replacement_pattern(
            rule.source for rule in self._replacements
        )

        # Keep hint payload bounded so model prompts do not balloon.
        self._stt_hotwords = self._build_stt_hotwords(limit=128, char_budget=1024)
        self._stt_initial_prompt = self._build_stt_initial_prompt(char_budget=600)

    def has_dictionary(self) -> bool:
        """Return True when any replacement rules or preferred terms exist."""
        return bool(self._replacements or self._terms)

    def apply_deterministic_replacements(self, text: str) -> str:
        """Substitute each configured source phrase in *text* with its target.

        Matching is case-insensitive and bounded by non-word characters on
        both sides; text without matches is returned unchanged.
        """
        if not text or self._replacement_pattern is None:
            return text

        def _replace(match: re.Match[str]) -> str:
            source_text = match.group(0)
            key = _normalize_key(source_text)
            # Fall back to the matched text itself if the key is somehow
            # absent, keeping substitution total and safe.
            return self._replacement_map.get(key, source_text)

        return self._replacement_pattern.sub(_replace, text)

    def build_stt_hints(self) -> tuple[str, str]:
        """Return the precomputed (hotwords, initial_prompt) STT hint pair."""
        return self._stt_hotwords, self._stt_initial_prompt

    def build_ai_dictionary_context(self, max_lines: int = 80, char_budget: int = 1500) -> str:
        """Render replacement rules and preferred terms as prompt context.

        Output is capped at *max_lines* lines and *char_budget* characters
        (counting joining newlines); returns "" when nothing is configured.
        """
        lines: list[str] = []
        for rule in self._replacements:
            lines.append(f"replace: {rule.source} -> {rule.target}")
        for term in self._terms:
            lines.append(f"prefer: {term}")

        if not lines:
            return ""

        out: list[str] = []
        used = 0
        for line in lines:
            if len(out) >= max_lines:
                break
            # +1 accounts for the newline join() inserts between lines.
            addition = len(line) + (1 if out else 0)
            if used + addition > char_budget:
                break
            out.append(line)
            used += addition
        return "\n".join(out)

    def infer_domain(self, text: str) -> DomainResult:
        """Guess the topic domain of *text* from keyword and phrase hits.

        Returns the general domain with confidence 0.0 when inference is
        disabled, no tokens are found, or no domain wins decisively
        (top score < 2 or winner's share of the total score < 0.45).
        """
        if not self._domain_enabled:
            return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)

        normalized = text.casefold()
        tokens = re.findall(r"[a-z0-9+#./_-]+", normalized)
        if not tokens:
            return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)

        scores = {domain: 0 for domain in DOMAIN_ORDER}
        for token in tokens:
            for domain, keywords in DOMAIN_KEYWORDS.items():
                if token in keywords:
                    scores[domain] += 2

        for domain, phrases in DOMAIN_PHRASES.items():
            for phrase in phrases:
                if phrase in normalized:
                    scores[domain] += 2

        # Fix: multi-word greetings ("good morning") can never equal a
        # single-word token, so match them against the full normalized text;
        # single-word greetings still require an exact token match to avoid
        # substring false positives (e.g. "hi" inside "this").
        token_set = set(tokens)
        if any(
            (greeting in normalized) if " " in greeting else (greeting in token_set)
            for greeting in GREETING_TOKENS
        ):
            scores[DOMAIN_PERSONAL_NAMES] += 1

        # Boost domains from configured dictionary terms and replacement targets.
        dictionary_tokens = self._dictionary_tokens()
        for token in dictionary_tokens:
            for domain, keywords in DOMAIN_KEYWORDS.items():
                if token in keywords and token in token_set:
                    scores[domain] += 1

        top_domain = DOMAIN_GENERAL
        top_score = 0
        total_score = 0
        for domain in DOMAIN_ORDER:
            score = scores[domain]
            total_score += score
            # Strict '>' means earlier entries in DOMAIN_ORDER win ties.
            if score > top_score:
                top_score = score
                top_domain = domain

        if top_score < 2 or total_score == 0:
            return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)

        confidence = top_score / total_score
        if confidence < 0.45:
            return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)

        return DomainResult(name=top_domain, confidence=round(confidence, 2))

    def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
        """Join replacement targets and terms into a bounded hotword string."""
        items = _dedupe_preserve_order(
            [rule.target for rule in self._replacements] + self._terms
        )
        words: list[str] = []
        used = 0
        for item in items:
            if len(words) >= limit:
                break
            # +2 accounts for the ", " separator join() inserts.
            addition = len(item) + (2 if words else 0)
            if used + addition > char_budget:
                break
            words.append(item)
            used += addition
        return ", ".join(words)

    def _build_stt_initial_prompt(self, *, char_budget: int) -> str:
        """Build a short natural-language STT prompt from the hotwords."""
        if not self._stt_hotwords:
            return ""
        prefix = "Preferred vocabulary: "
        available = max(char_budget - len(prefix), 0)
        # Truncate to budget, then trim any dangling separator characters.
        hotwords = self._stt_hotwords[:available].rstrip(", ")
        if not hotwords:
            return ""
        return prefix + hotwords

    def _dictionary_tokens(self) -> set[str]:
        """Tokenize all dictionary sources, targets, and terms (casefolded)."""
        values: list[str] = []
        for rule in self._replacements:
            values.append(rule.source)
            values.append(rule.target)
        values.extend(self._terms)

        tokens: set[str] = set()
        for value in values:
            for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()):
                tokens.add(token)
        return tokens
|
||||
|
||||
|
||||
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
    """Compile one case-insensitive alternation matching any source phrase.

    Longer phrases are placed first so the regex engine prefers the longest
    candidate at a given position; returns None when no usable source exists.
    """
    candidates = _dedupe_preserve_order(list(sources))
    if not candidates:
        return None

    # Longest first, then case-insensitive alphabetical for stable ordering.
    candidates.sort(key=lambda text: (-len(text), text.casefold()))
    alternation = "|".join(map(re.escape, candidates))
    # (?<!\w) / (?!\w) forbid adjacent word characters, giving whole-phrase
    # matching without consuming the surrounding delimiters.
    return re.compile(rf"(?<!\w)({alternation})(?!\w)", flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _dedupe_preserve_order(values: list[str]) -> list[str]:
    """Strip entries, drop blanks, and keep only the first occurrence of each
    value under case/whitespace-insensitive comparison, preserving order."""
    kept: list[str] = []
    seen_keys: set[str] = set()
    for raw in values:
        candidate = raw.strip()
        if not candidate:
            continue
        normalized = _normalize_key(candidate)
        if normalized not in seen_keys:
            seen_keys.add(normalized)
            kept.append(candidate)
    return kept
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return " ".join(value.casefold().split())
|
||||
Loading…
Add table
Add a link
Reference in a new issue