"""Vocabulary engine: deterministic text replacements plus STT/AI hint payloads."""
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Iterable
|
|
|
|
from config import VocabularyConfig
|
|
|
|
|
|
@dataclass(frozen=True)
class _ReplacementView:
    """One replacement rule: a source phrase and the target it maps to.

    Frozen snapshot of a config rule, so the engine's state cannot change
    if the originating VocabularyConfig is mutated later.
    """

    # Phrase looked up (after key normalization) in incoming text.
    source: str
    # Canonical text substituted for a matched source phrase.
    target: str
|
|
|
|
|
|
class VocabularyEngine:
    """Applies user vocabulary rules to text and derives bounded STT/AI hints.

    All derived state (replacement map, compiled pattern, STT hint strings)
    is computed once at construction so later calls are cheap.
    """

    def __init__(self, vocab_cfg: VocabularyConfig):
        # Snapshot the config into frozen views so later config mutation
        # cannot leak into this engine.
        self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
        self._terms = list(vocab_cfg.terms)

        # The regex finds candidate spans; this map (keyed by normalized
        # source text) supplies the canonical target for each match.
        self._replacement_map = {
            _normalize_key(rule.source): rule.target for rule in self._replacements
        }
        self._replacement_pattern = _build_replacement_pattern(
            rule.source for rule in self._replacements
        )

        # Keep hint payload bounded so model prompts do not balloon.
        self._stt_hotwords = self._build_stt_hotwords(limit=128, char_budget=1024)
        self._stt_initial_prompt = self._build_stt_initial_prompt(char_budget=600)

    def has_dictionary(self) -> bool:
        """Return True when at least one replacement rule or term is loaded."""
        return bool(self._replacements) or bool(self._terms)

    def apply_deterministic_replacements(self, text: str) -> str:
        """Rewrite every dictionary match in *text* to its canonical target.

        Returns *text* untouched when it is empty or no rules are configured.
        A match whose normalized form is missing from the map is left as-is.
        """
        pattern = self._replacement_pattern
        if pattern is None or not text:
            return text

        mapping = self._replacement_map

        def _substitute(match: re.Match[str]) -> str:
            matched = match.group(0)
            return mapping.get(_normalize_key(matched), matched)

        return pattern.sub(_substitute, text)

    def build_stt_hints(self) -> tuple[str, str]:
        """Return the precomputed (hotwords, initial_prompt) pair."""
        return self._stt_hotwords, self._stt_initial_prompt

    def build_ai_dictionary_context(self, max_lines: int = 80, char_budget: int = 1500) -> str:
        """Render the dictionary as prompt lines, capped by count and size.

        Emits one "replace: src -> tgt" line per rule followed by one
        "prefer: term" line per term, newline-joined, stopping before either
        cap would be exceeded. Empty string when there is nothing to emit.
        """
        entries = [f"replace: {rule.source} -> {rule.target}" for rule in self._replacements]
        entries.extend(f"prefer: {term}" for term in self._terms)
        if not entries:
            return ""

        kept = self._clip_to_budget(
            entries, max_items=max_lines, char_budget=char_budget, sep_width=1
        )
        return "\n".join(kept)

    def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
        """Comma-join deduped replacement targets and terms within the caps."""
        candidates = _dedupe_preserve_order(
            [rule.target for rule in self._replacements] + self._terms
        )
        kept = self._clip_to_budget(
            candidates, max_items=limit, char_budget=char_budget, sep_width=2
        )
        return ", ".join(kept)

    def _build_stt_initial_prompt(self, *, char_budget: int) -> str:
        """Wrap the hotword list in a short instruction, within char_budget.

        Empty string when there are no hotwords or the budget leaves no room
        after the fixed prefix.
        """
        if not self._stt_hotwords:
            return ""
        prefix = "Preferred vocabulary: "
        room = max(char_budget - len(prefix), 0)
        clipped = self._stt_hotwords[:room].rstrip(", ")
        return prefix + clipped if clipped else ""

    @staticmethod
    def _clip_to_budget(
        candidates: list[str], *, max_items: int, char_budget: int, sep_width: int
    ) -> list[str]:
        """Take the longest prefix of *candidates* within both caps.

        sep_width is the length of the separator charged for every item after
        the first, mirroring how the caller will join the result.
        """
        kept: list[str] = []
        consumed = 0
        for candidate in candidates:
            if len(kept) >= max_items:
                break
            cost = len(candidate) + (sep_width if kept else 0)
            if consumed + cost > char_budget:
                break
            kept.append(candidate)
            consumed += cost
        return kept
|
|
|
|
|
|
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
    """Compile one case-insensitive alternation matching any source phrase.

    Duplicates are removed first; the surviving phrases are ordered
    longest-first so a longer phrase always beats a shorter prefix of itself.
    Returns None when no usable phrase remains.
    """
    phrases = _dedupe_preserve_order(list(sources))
    if not phrases:
        return None

    # Longest first, ties broken alphabetically, so the alternation prefers
    # the most specific phrase at each position.
    ordered = sorted(phrases, key=lambda phrase: (-len(phrase), phrase.casefold()))
    alternation = "|".join(re.escape(phrase) for phrase in ordered)
    # \w look-arounds approximate word boundaries without consuming text.
    return re.compile(rf"(?<!\w)({alternation})(?!\w)", flags=re.IGNORECASE)
|
|
|
|
|
|
def _dedupe_preserve_order(values: list[str]) -> list[str]:
    """Return stripped, non-blank values with later duplicates dropped.

    Duplicate detection is case- and whitespace-insensitive (via
    _normalize_key); the first spelling encountered is the one kept, and
    input order is preserved.
    """
    kept: list[str] = []
    seen_keys: set[str] = set()
    for raw in values:
        candidate = raw.strip()
        if not candidate:
            continue
        key = _normalize_key(candidate)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)
    return kept
|
|
|
|
|
|
def _normalize_key(value: str) -> str:
|
|
return " ".join(value.casefold().split())
|