Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed
ci / test-and-build (push) Has been cancelled
Some checks failed
ci / test-and-build (push) Has been cancelled
This commit is contained in:
parent
98b13d1069
commit
8c1f7c1e13
38 changed files with 5300 additions and 503 deletions
19
src/stages/__init__.py
Normal file
19
src/stages/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""Public surface of the pipeline stage package.

Re-exports the stage implementations and their result/decision types so
callers can import everything from ``stages`` directly.
"""

from .alignment_edits import AlignmentDecision, AlignmentHeuristicEngine, AlignmentResult
from .asr_whisper import AsrResult, AsrSegment, AsrWord, WhisperAsrStage
from .editor_llama import EditorResult, LlamaEditorStage
from .fact_guard import FactGuardEngine, FactGuardResult, FactGuardViolation

# Explicit public API for `from stages import *`.
__all__ = [
    "AlignmentDecision",
    "AlignmentHeuristicEngine",
    "AlignmentResult",
    "AsrResult",
    "AsrSegment",
    "AsrWord",
    "WhisperAsrStage",
    "EditorResult",
    "LlamaEditorStage",
    "FactGuardEngine",
    "FactGuardResult",
    "FactGuardViolation",
]
|
||||
298
src/stages/alignment_edits.py
Normal file
298
src/stages/alignment_edits.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
from stages.asr_whisper import AsrWord
|
||||
|
||||
|
||||
# Max pause (seconds) allowed between an "i mean" cue and the previous word
# for the cue to count as a self-correction (see _match_cue).
_MAX_CUE_GAP_S = 0.85
# Cap on how many words a correction phrase may span (see _capture_phrase_end).
_MAX_CORRECTION_WORDS = 8
# Cap on how many already-emitted words may be rewritten as left context
# (see _find_left_context_start).
_MAX_LEFT_CONTEXT_WORDS = 8
# A pause longer than this inside a phrase ends it (see _capture_phrase_end).
_MAX_PHRASE_GAP_S = 0.9
|
||||
|
||||
|
||||
@dataclass
class AlignmentDecision:
    """A single heuristic decision recorded while building the draft text."""

    rule_id: str  # which rule fired, e.g. "cue_correction" or "restart_repeat"
    span_start: int  # index of the first word/token covered by the decision
    span_end: int  # exclusive end index of the covered span
    replacement: str  # text applied; empty when the decision was skipped
    confidence: str  # "high" when applied, "low" when skipped
    reason: str  # human-readable explanation of the decision
|
||||
|
||||
|
||||
@dataclass
class AlignmentResult:
    """Outcome of AlignmentHeuristicEngine.apply for one transcript."""

    draft_text: str  # rewritten transcript (original text when nothing applied)
    decisions: list[AlignmentDecision]  # every decision considered, applied or not
    applied_count: int  # decisions with confidence == "high"
    skipped_count: int  # remaining (low-confidence) decisions
|
||||
|
||||
|
||||
class AlignmentHeuristicEngine:
    """Applies self-correction heuristics to an ASR transcript.

    Uses spoken cue words ("i mean", "actually", "sorry", "no") plus word
    timings to replace a phrase with the correction that follows it, and
    collapses exact repeated phrase restarts.
    """

    def apply(self, transcript: str, words: list[AsrWord]) -> AlignmentResult:
        """Rewrite *transcript* using the timed *words*; falls back to the
        original text when no high-confidence decision is made."""
        base_text = (transcript or "").strip()
        if not base_text or not words:
            # Nothing to align: return the (possibly empty) text unchanged.
            return AlignmentResult(
                draft_text=base_text,
                decisions=[],
                applied_count=0,
                skipped_count=0,
            )

        normalized_words = [_normalize_token(word.text) for word in words]
        literal_guard = _has_literal_guard(base_text)
        out_tokens: list[str] = []
        decisions: list[AlignmentDecision] = []
        i = 0
        while i < len(words):
            cue = _match_cue(words, normalized_words, i)
            # A cue only makes sense when there is prior output to correct.
            if cue is not None and out_tokens:
                cue_len, cue_label = cue
                correction_start = i + cue_len
                correction_end = _capture_phrase_end(words, correction_start)
                if correction_end <= correction_start:
                    # Cue at end of input with no correction phrase: record a
                    # skipped decision and move past the cue words.
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=min(i + cue_len, len(words)),
                            replacement="",
                            confidence="low",
                            reason=f"{cue_label} has no correction phrase",
                        )
                    )
                    i += cue_len
                    continue
                correction_tokens = _slice_clean_words(words, correction_start, correction_end)
                if not correction_tokens:
                    # Only empty/punctuation words followed the cue.
                    i = correction_end
                    continue
                left_start = _find_left_context_start(out_tokens)
                left_tokens = out_tokens[left_start:]
                candidate = _compose_replacement(left_tokens, correction_tokens)
                if candidate is None:
                    # No safe way to merge the correction; keep text literal.
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=correction_end,
                            replacement="",
                            confidence="low",
                            reason=f"{cue_label} ambiguous; preserving literal",
                        )
                    )
                    i += 1
                    continue
                if literal_guard and cue_label == "i mean":
                    # User asked for verbatim dictation; do not edit "i mean".
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=correction_end,
                            replacement="",
                            confidence="low",
                            reason="literal dictation context; preserving 'i mean'",
                        )
                    )
                    i += 1
                    continue

                # Apply: rewrite the left context with the merged candidate.
                out_tokens = out_tokens[:left_start] + candidate
                decisions.append(
                    AlignmentDecision(
                        rule_id="cue_correction",
                        span_start=i,
                        span_end=correction_end,
                        replacement=" ".join(candidate),
                        confidence="high",
                        reason=f"{cue_label} replaces prior phrase with correction phrase",
                    )
                )
                i = correction_end
                continue

            # Default path: emit the word as-is (stripped of punctuation).
            token = _strip_token(words[i].text)
            if token:
                out_tokens.append(token)
            i += 1

        # Second pass: collapse exact repeated phrase restarts.
        out_tokens, repeat_decisions = _collapse_restarts(out_tokens)
        decisions.extend(repeat_decisions)

        applied_count = sum(1 for decision in decisions if decision.confidence == "high")
        skipped_count = len(decisions) - applied_count
        if applied_count <= 0:
            # Nothing was applied: return the untouched original text.
            return AlignmentResult(
                draft_text=base_text,
                decisions=decisions,
                applied_count=0,
                skipped_count=skipped_count,
            )

        return AlignmentResult(
            draft_text=_render_words(out_tokens),
            decisions=decisions,
            applied_count=applied_count,
            skipped_count=skipped_count,
        )
|
||||
|
||||
|
||||
def _match_cue(words: list[AsrWord], normalized_words: list[str], index: int) -> tuple[int, str] | None:
    """Return ``(cue_length, cue_label)`` when a correction cue starts at *index*, else None."""
    if index >= len(words):
        return None
    head = normalized_words[index]
    if head == "i":
        # "i mean" only counts when it follows the previous word closely.
        has_mean = index + 1 < len(words) and normalized_words[index + 1] == "mean"
        if has_mean and _gap_since_previous(words, index) <= _MAX_CUE_GAP_S:
            return (2, "i mean")
        return None
    if head in ("actually", "sorry"):
        return (1, head)
    if head == "no":
        # Require the raw token to be a standalone "no" / "no," marker.
        if words[index].text.strip().lower() in {"no", "no,"}:
            return (1, "no")
    return None
|
||||
|
||||
|
||||
def _gap_since_previous(words: list[AsrWord], index: int) -> float:
|
||||
if index <= 0:
|
||||
return 0.0
|
||||
previous = words[index - 1]
|
||||
current = words[index]
|
||||
if previous.end_s <= 0.0 or current.start_s <= 0.0:
|
||||
return 0.0
|
||||
return max(current.start_s - previous.end_s, 0.0)
|
||||
|
||||
|
||||
def _capture_phrase_end(words: list[AsrWord], start: int) -> int:
    """Scan forward from *start* and return the exclusive end of the correction phrase."""
    total = len(words)
    cursor = start
    while cursor < total:
        # Cap how many words the phrase may span.
        if cursor - start >= _MAX_CORRECTION_WORDS:
            break
        token = words[cursor].text.strip()
        cursor += 1
        if not token:
            continue
        # Sentence-final punctuation closes the phrase.
        if _ends_sentence(token):
            break
        # A long pause before the next word also closes the phrase.
        if cursor < total:
            pause = max(words[cursor].start_s - words[cursor - 1].end_s, 0.0)
            if pause > _MAX_PHRASE_GAP_S:
                break
    return cursor
|
||||
|
||||
|
||||
def _find_left_context_start(tokens: list[str]) -> int:
    """Index where rewritable left context begins.

    Walks backwards at most _MAX_LEFT_CONTEXT_WORDS tokens and never crosses
    a sentence boundary.
    """
    position = len(tokens)
    for _ in range(_MAX_LEFT_CONTEXT_WORDS):
        if position <= 0 or _ends_sentence(tokens[position - 1]):
            break
        position -= 1
    return position
|
||||
|
||||
|
||||
def _compose_replacement(left_tokens: list[str], correction_tokens: list[str]) -> list[str] | None:
    """Merge a correction phrase into the left context.

    Returns the replacement token list, or ``None`` when the correction is
    too ambiguous to apply safely.
    """
    if not left_tokens or not correction_tokens:
        return None

    left_norm = [_normalize_token(token) for token in left_tokens]
    right_norm = [_normalize_token(token) for token in correction_tokens]
    if not any(right_norm):
        return None

    # Shared prefix: keep it from the left, take the rest from the correction.
    prefix = 0
    for lhs, rhs in zip(left_norm, right_norm):
        if lhs != rhs:
            break
        prefix += 1
    if prefix > 0:
        return left_tokens[:prefix] + correction_tokens[prefix:]

    # Short correction (1-2 words): replace the trailing words of the left
    # context one-for-one. Since left_tokens is non-empty (guard above), this
    # branch covers every correction of length <= 2 — the original's extra
    # `len(correction_tokens) == 1` branch below it was unreachable and has
    # been removed.
    if len(correction_tokens) <= 2:
        keep = max(len(left_tokens) - len(correction_tokens), 0)
        return left_tokens[:keep] + correction_tokens

    return None
|
||||
|
||||
|
||||
def _collapse_restarts(tokens: list[str]) -> tuple[list[str], list[AlignmentDecision]]:
    """Collapse exact repeated phrase restarts ("the cat the cat sat" -> "the cat sat").

    Repeats until no change is made; only chunks of 3..6 tokens are
    considered (range(max_chunk, 2, -1)), preferring longer matches first.
    Returns the collapsed tokens plus a decision record per collapse.
    """
    if len(tokens) < 6:
        # Too short to contain a repeated chunk of >= 3 tokens.
        return tokens, []
    mutable = list(tokens)
    decisions: list[AlignmentDecision] = []
    changed = True
    while changed:
        changed = False
        max_chunk = min(6, len(mutable) // 2)
        for chunk_size in range(max_chunk, 2, -1):
            index = 0
            while index + (2 * chunk_size) <= len(mutable):
                # Compare two adjacent windows of chunk_size tokens,
                # case/punctuation-insensitively.
                left = [_normalize_token(item) for item in mutable[index : index + chunk_size]]
                right = [_normalize_token(item) for item in mutable[index + chunk_size : index + (2 * chunk_size)]]
                if left == right and left:
                    # Keep the second occurrence's original spelling; drop the first.
                    replacement = " ".join(mutable[index + chunk_size : index + (2 * chunk_size)])
                    del mutable[index : index + chunk_size]
                    decisions.append(
                        AlignmentDecision(
                            rule_id="restart_repeat",
                            span_start=index,
                            span_end=index + (2 * chunk_size),
                            replacement=replacement,
                            confidence="high",
                            reason="collapsed exact repeated phrase restart",
                        )
                    )
                    # Restart the scan: the list was mutated in place.
                    changed = True
                    break
                index += 1
            if changed:
                break
    return mutable, decisions
|
||||
|
||||
|
||||
def _slice_clean_words(words: list[AsrWord], start: int, end: int) -> list[str]:
    """Stripped, non-empty word texts for ``words[start:end]``."""
    stripped = (_strip_token(word.text) for word in words[start:end])
    return [token for token in stripped if token]
|
||||
|
||||
|
||||
def _strip_token(text: str) -> str:
|
||||
token = (text or "").strip()
|
||||
if not token:
|
||||
return ""
|
||||
# Remove surrounding punctuation but keep internal apostrophes/hyphens.
|
||||
return token.strip(" \t\n\r\"'`“”‘’.,!?;:()[]{}")
|
||||
|
||||
|
||||
def _normalize_token(text: str) -> str:
    """Case-insensitive comparison key for a token."""
    stripped = _strip_token(text)
    return stripped.casefold()
|
||||
|
||||
|
||||
def _render_words(tokens: Iterable[str]) -> str:
|
||||
cleaned = [token.strip() for token in tokens if token and token.strip()]
|
||||
return " ".join(cleaned).strip()
|
||||
|
||||
|
||||
def _ends_sentence(token: str) -> bool:
|
||||
trimmed = (token or "").strip()
|
||||
return trimmed.endswith(".") or trimmed.endswith("!") or trimmed.endswith("?")
|
||||
|
||||
|
||||
def _has_literal_guard(text: str) -> bool:
|
||||
normalized = " ".join((text or "").casefold().split())
|
||||
guards = (
|
||||
"write exactly",
|
||||
"keep this literal",
|
||||
"keep literal",
|
||||
"verbatim",
|
||||
"quote",
|
||||
)
|
||||
return any(guard in normalized for guard in guards)
|
||||
134
src/stages/asr_whisper.py
Normal file
134
src/stages/asr_whisper.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import inspect
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass
class AsrWord:
    """A single recognized word with timing from the ASR engine."""

    text: str  # raw word text as emitted by the model
    start_s: float  # start time in seconds (0.0 when unknown)
    end_s: float  # end time in seconds (0.0 when unknown)
    prob: float | None  # word probability when the model provides one
|
||||
|
||||
|
||||
@dataclass
class AsrSegment:
    """A contiguous recognized segment with timing."""

    text: str  # stripped segment text
    start_s: float  # start time in seconds (0.0 when unknown)
    end_s: float  # end time in seconds (0.0 when unknown)
|
||||
|
||||
|
||||
@dataclass
class AsrResult:
    """Full output of one WhisperAsrStage.transcribe call."""

    raw_text: str  # all segment texts joined with spaces
    language: str  # configured language, or "auto" after hint fallback
    latency_ms: float  # wall-clock transcription time in milliseconds
    words: list[AsrWord]  # word-level results (empty when unavailable)
    segments: list[AsrSegment]  # segment-level results
|
||||
|
||||
|
||||
def _is_stt_language_hint_error(exc: Exception) -> bool:
|
||||
text = str(exc).casefold()
|
||||
has_language = "language" in text
|
||||
unsupported = "unsupported" in text or "not supported" in text or "unknown" in text
|
||||
return has_language and unsupported
|
||||
|
||||
|
||||
class WhisperAsrStage:
    """ASR stage around a whisper-style model exposing ``transcribe``.

    The model is duck-typed: segments and words are read with getattr, so any
    object with ``text``/``start``/``end`` (and word-level ``word``/``start``/
    ``end``/``probability``) works.
    """

    def __init__(
        self,
        model: Any,
        *,
        configured_language: str,
        hint_kwargs_provider: Callable[[], dict[str, Any]] | None = None,
    ) -> None:
        self._model = model
        # Normalize to lowercase; blank/None means auto-detect.
        self._configured_language = (configured_language or "auto").strip().lower() or "auto"
        self._hint_kwargs_provider = hint_kwargs_provider or (lambda: {})
        # Probe once whether the model accepts word_timestamps.
        self._supports_word_timestamps = _supports_parameter(model.transcribe, "word_timestamps")

    def transcribe(self, audio: Any) -> AsrResult:
        """Run the model on *audio*, retrying without the language hint when
        the model rejects it, and normalize the output into an AsrResult."""
        kwargs: dict[str, Any] = {"vad_filter": True}
        if self._configured_language != "auto":
            kwargs["language"] = self._configured_language
        if self._supports_word_timestamps:
            kwargs["word_timestamps"] = True
        # Provider-supplied kwargs may override the defaults above.
        kwargs.update(self._hint_kwargs_provider())

        effective_language = self._configured_language
        started = time.perf_counter()
        try:
            segments, _info = self._model.transcribe(audio, **kwargs)
        except Exception as exc:
            if self._configured_language != "auto" and _is_stt_language_hint_error(exc):
                # The model rejected the language hint: retry auto-detecting.
                logging.warning(
                    "stt language hint '%s' was rejected; falling back to auto-detect",
                    self._configured_language,
                )
                fallback_kwargs = dict(kwargs)
                fallback_kwargs.pop("language", None)
                segments, _info = self._model.transcribe(audio, **fallback_kwargs)
                effective_language = "auto"
            else:
                raise

        parts: list[str] = []
        words: list[AsrWord] = []
        asr_segments: list[AsrSegment] = []
        for seg in segments:
            text = (getattr(seg, "text", "") or "").strip()
            if text:
                parts.append(text)
                # Missing/None timestamps default to 0.0.
                start_s = float(getattr(seg, "start", 0.0) or 0.0)
                end_s = float(getattr(seg, "end", 0.0) or 0.0)
                asr_segments.append(
                    AsrSegment(
                        text=text,
                        start_s=start_s,
                        end_s=end_s,
                    )
                )
            segment_words = getattr(seg, "words", None)
            if not segment_words:
                continue
            for word in segment_words:
                token = (getattr(word, "word", "") or "").strip()
                if not token:
                    continue
                words.append(
                    AsrWord(
                        text=token,
                        start_s=float(getattr(word, "start", 0.0) or 0.0),
                        end_s=float(getattr(word, "end", 0.0) or 0.0),
                        prob=_optional_float(getattr(word, "probability", None)),
                    )
                )
        latency_ms = (time.perf_counter() - started) * 1000.0
        return AsrResult(
            raw_text=" ".join(parts).strip(),
            language=effective_language,
            latency_ms=latency_ms,
            words=words,
            segments=asr_segments,
        )
|
||||
|
||||
|
||||
def _supports_parameter(callable_obj: Callable[..., Any], parameter: str) -> bool:
|
||||
try:
|
||||
signature = inspect.signature(callable_obj)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return parameter in signature.parameters
|
||||
|
||||
|
||||
def _optional_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
64
src/stages/editor_llama.py
Normal file
64
src/stages/editor_llama.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class EditorResult:
    """Output of one LlamaEditorStage.rewrite call."""

    final_text: str  # stripped rewritten text
    latency_ms: float  # total rewrite time in milliseconds
    pass1_ms: float  # first-pass time (0.0 when the processor reports none)
    pass2_ms: float  # second-pass time (total when no per-pass metrics exist)
||||
|
||||
|
||||
class LlamaEditorStage:
    """Editor stage delegating to a text processor object.

    Prefers the processor's ``process_with_metrics`` API when present and
    falls back to plain ``process`` otherwise.
    """

    def __init__(self, processor: Any, *, profile: str = "default") -> None:
        self._processor = processor
        # Normalize; blank/None collapses to "default".
        self._profile = (profile or "default").strip().lower() or "default"

    def set_profile(self, profile: str) -> None:
        """Switch the editing profile used for subsequent calls."""
        self._profile = (profile or "default").strip().lower() or "default"

    def warmup(self) -> None:
        """Ask the processor to warm up with the current profile."""
        self._processor.warmup(profile=self._profile)

    def rewrite(
        self,
        transcript: str,
        *,
        language: str,
        dictionary_context: str,
    ) -> EditorResult:
        """Rewrite *transcript*, returning the edited text plus timings."""
        started = time.perf_counter()
        if hasattr(self._processor, "process_with_metrics"):
            final_text, timings = self._processor.process_with_metrics(
                transcript,
                lang=language,
                dictionary_context=dictionary_context,
                profile=self._profile,
            )
            latency_ms = float(getattr(timings, "total_ms", 0.0))
            if latency_ms <= 0.0:
                # Processor reported no usable total: fall back to wall clock.
                latency_ms = (time.perf_counter() - started) * 1000.0
            return EditorResult(
                final_text=(final_text or "").strip(),
                latency_ms=latency_ms,
                pass1_ms=float(getattr(timings, "pass1_ms", 0.0)),
                pass2_ms=float(getattr(timings, "pass2_ms", 0.0)),
            )

        # Legacy path: no per-pass metrics; attribute everything to pass 2.
        final_text = self._processor.process(
            transcript,
            lang=language,
            dictionary_context=dictionary_context,
            profile=self._profile,
        )
        latency_ms = (time.perf_counter() - started) * 1000.0
        return EditorResult(
            final_text=(final_text or "").strip(),
            latency_ms=latency_ms,
            pass1_ms=0.0,
            pass2_ms=latency_ms,
        )
|
||||
294
src/stages/fact_guard.py
Normal file
294
src/stages/fact_guard.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
@dataclass
class FactGuardViolation:
    """A single fact-preservation rule violation."""

    rule_id: str  # "entity_preservation", "entity_invention", or "diff_additions"
    severity: str  # "high" or "medium"
    source_span: str  # offending text from the source ("" when not applicable)
    candidate_span: str  # offending text from the candidate ("" when not applicable)
    reason: str  # human-readable explanation
|
||||
|
||||
|
||||
@dataclass
class FactGuardResult:
    """Outcome of FactGuardEngine.apply."""

    final_text: str  # candidate when accepted, source when rejected/fallback
    action: str  # "accepted", "rejected" (strict) or "fallback"
    violations: list[FactGuardViolation]  # deduplicated violations found
    violations_count: int  # len(violations), kept for convenience
    latency_ms: float  # time spent in the guard, milliseconds
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _FactEntity:
    """An extracted factual entity keyed for cross-text comparison."""

    key: str  # normalized (casefolded, whitespace-collapsed) lookup key
    value: str  # original cleaned text
    kind: str  # "url", "email", "identifier", "number", or "name_or_term"
    severity: str  # "high" for structural entities, "medium" for names/terms
||||
|
||||
|
||||
# Entity-extraction patterns used by _extract_entities / _diff_tokens.
_URL_RE = re.compile(r"\bhttps?://[^\s<>\"']+")  # http(s) URLs
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")  # email addresses
_ID_RE = re.compile(r"\b(?:[A-Z]{2,}[-_]\d+[A-Z0-9]*|[A-Za-z]+\d+[A-Za-z0-9_-]*)\b")  # codes like ABC-123, abc42
_NUMBER_RE = re.compile(r"\b\d+(?:[.,:]\d+)*(?:%|am|pm)?\b", re.IGNORECASE)  # numbers, times, percentages
_TOKEN_RE = re.compile(r"\b[^\s]+\b")  # whitespace-delimited tokens
_DIFF_TOKEN_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9'_-]*")  # word tokens for diffing
|
||||
|
||||
_SOFT_TOKENS = {
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"are",
|
||||
"as",
|
||||
"at",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"but",
|
||||
"by",
|
||||
"for",
|
||||
"from",
|
||||
"if",
|
||||
"in",
|
||||
"is",
|
||||
"it",
|
||||
"its",
|
||||
"of",
|
||||
"on",
|
||||
"or",
|
||||
"that",
|
||||
"the",
|
||||
"their",
|
||||
"then",
|
||||
"there",
|
||||
"these",
|
||||
"they",
|
||||
"this",
|
||||
"those",
|
||||
"to",
|
||||
"was",
|
||||
"we",
|
||||
"were",
|
||||
"with",
|
||||
"you",
|
||||
"your",
|
||||
}
|
||||
|
||||
_NON_FACT_WORDS = {
|
||||
"please",
|
||||
"hello",
|
||||
"hi",
|
||||
"thanks",
|
||||
"thank",
|
||||
"set",
|
||||
"send",
|
||||
"write",
|
||||
"schedule",
|
||||
"call",
|
||||
"meeting",
|
||||
"email",
|
||||
"message",
|
||||
"note",
|
||||
}
|
||||
|
||||
|
||||
class FactGuardEngine:
    """Guards an edited candidate text against factual drift from its source.

    Compares extracted entities in both directions and (in strict mode) the
    token-level diff; any violation causes the source text to be returned
    instead of the candidate.
    """

    def apply(
        self,
        source_text: str,
        candidate_text: str,
        *,
        enabled: bool,
        strict: bool,
    ) -> FactGuardResult:
        """Check *candidate_text* against *source_text* and pick the safe one."""
        started = time.perf_counter()
        source = (source_text or "").strip()
        # An empty candidate falls back to the source text.
        candidate = (candidate_text or "").strip() or source
        if not enabled:
            # Guard disabled: accept the candidate untouched.
            return FactGuardResult(
                final_text=candidate,
                action="accepted",
                violations=[],
                violations_count=0,
                latency_ms=(time.perf_counter() - started) * 1000.0,
            )

        violations: list[FactGuardViolation] = []
        source_entities = _extract_entities(source)
        candidate_entities = _extract_entities(candidate)

        # Entities present in the source must survive in the candidate.
        for key, entity in source_entities.items():
            if key in candidate_entities:
                continue
            violations.append(
                FactGuardViolation(
                    rule_id="entity_preservation",
                    severity=entity.severity,
                    source_span=entity.value,
                    candidate_span="",
                    reason=f"candidate dropped source {entity.kind}",
                )
            )

        # The candidate must not invent entities absent from the source.
        for key, entity in candidate_entities.items():
            if key in source_entities:
                continue
            violations.append(
                FactGuardViolation(
                    rule_id="entity_invention",
                    severity=entity.severity,
                    source_span="",
                    candidate_span=entity.value,
                    reason=f"candidate introduced new {entity.kind}",
                )
            )

        if strict:
            # Strict mode additionally blocks substantial wording additions.
            additions = _strict_additions(source, candidate)
            if additions:
                additions_preview = " ".join(additions[:8])
                violations.append(
                    FactGuardViolation(
                        rule_id="diff_additions",
                        severity="high",
                        source_span="",
                        candidate_span=additions_preview,
                        reason="strict mode blocks substantial lexical additions",
                    )
                )

        violations = _dedupe_violations(violations)
        action = "accepted"
        final_text = candidate
        if violations:
            # Any violation reverts to the source text.
            action = "rejected" if strict else "fallback"
            final_text = source

        return FactGuardResult(
            final_text=final_text,
            action=action,
            violations=violations,
            violations_count=len(violations),
            latency_ms=(time.perf_counter() - started) * 1000.0,
        )
|
||||
|
||||
|
||||
def _extract_entities(text: str) -> dict[str, _FactEntity]:
    """Collect urls, emails, identifiers, numbers and capitalized terms from *text*."""
    found: dict[str, _FactEntity] = {}
    if not text:
        return found

    # Structural entities are high severity; first occurrence wins per key.
    high_severity_patterns = (
        (_URL_RE, "url"),
        (_EMAIL_RE, "email"),
        (_ID_RE, "identifier"),
        (_NUMBER_RE, "number"),
    )
    for pattern, kind in high_severity_patterns:
        for match in pattern.finditer(text):
            _add_entity(found, match.group(0), kind=kind, severity="high")

    # Capitalized/mixed-case terms are medium severity.
    for match in _TOKEN_RE.finditer(text):
        token = _clean_token(match.group(0))
        if not token or token.casefold() in _NON_FACT_WORDS:
            continue
        if _looks_name_or_term(token):
            _add_entity(found, token, kind="name_or_term", severity="medium")
    return found
|
||||
|
||||
|
||||
def _add_entity(
    entities: dict[str, _FactEntity],
    token: str,
    *,
    kind: str,
    severity: str,
) -> None:
    """Insert *token* into *entities* under its normalized key; first occurrence wins."""
    cleaned = _clean_token(token)
    if not cleaned:
        return
    key = _normalize_key(cleaned)
    if not key or key in entities:
        return
    entities[key] = _FactEntity(key=key, value=cleaned, kind=kind, severity=severity)
|
||||
|
||||
|
||||
def _strict_additions(source_text: str, candidate_text: str) -> list[str]:
    """Meaningful tokens added by the candidate, when substantial enough to flag.

    Returns [] unless at least two meaningful tokens were added AND they make
    up at least 10% of the source length.
    """
    source_tokens = _diff_tokens(source_text)
    candidate_tokens = _diff_tokens(candidate_text)
    if not source_tokens or not candidate_tokens:
        return []

    opcodes = SequenceMatcher(a=source_tokens, b=candidate_tokens).get_opcodes()
    added = [
        token
        for tag, _i1, _i2, j1, j2 in opcodes
        if tag in {"insert", "replace"}
        for token in candidate_tokens[j1:j2]
    ]
    meaningful = [token for token in added if _is_meaningful_added_token(token)]
    if not meaningful:
        return []
    added_ratio = len(meaningful) / max(len(source_tokens), 1)
    return meaningful if (len(meaningful) >= 2 and added_ratio >= 0.1) else []
|
||||
|
||||
|
||||
def _diff_tokens(text: str) -> list[str]:
    """Casefolded word tokens used for source/candidate diffing."""
    matches = _DIFF_TOKEN_RE.finditer(text or "")
    return [match.group(0).casefold() for match in matches]
|
||||
|
||||
|
||||
def _looks_name_or_term(token: str) -> bool:
|
||||
if len(token) < 2:
|
||||
return False
|
||||
if any(ch.isdigit() for ch in token):
|
||||
return False
|
||||
has_upper = any(ch.isupper() for ch in token)
|
||||
if not has_upper:
|
||||
return False
|
||||
if token.isupper() and len(token) >= 2:
|
||||
return True
|
||||
if token[0].isupper():
|
||||
return True
|
||||
# Mixed-case term like "iPhone".
|
||||
return any(ch.isupper() for ch in token[1:])
|
||||
|
||||
|
||||
def _is_meaningful_added_token(token: str) -> bool:
    """False for single characters and common function words; True otherwise."""
    return len(token) > 1 and token not in _SOFT_TOKENS
|
||||
|
||||
|
||||
def _clean_token(token: str) -> str:
|
||||
return (token or "").strip(" \t\n\r\"'`.,!?;:()[]{}")
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return " ".join((value or "").casefold().split())
|
||||
|
||||
|
||||
def _dedupe_violations(violations: list[FactGuardViolation]) -> list[FactGuardViolation]:
|
||||
deduped: list[FactGuardViolation] = []
|
||||
seen: set[tuple[str, str, str, str]] = set()
|
||||
for item in violations:
|
||||
key = (item.rule_id, item.severity, item.source_span.casefold(), item.candidate_span.casefold())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(item)
|
||||
return deduped
|
||||
Loading…
Add table
Add a link
Reference in a new issue