Add benchmark-driven model promotion workflow and pipeline stages
Some checks failed
ci / test-and-build (push) Has been cancelled

This commit is contained in:
Thales Maciel 2026-02-28 15:12:33 -03:00
parent 98b13d1069
commit 8c1f7c1e13
38 changed files with 5300 additions and 503 deletions

19
src/stages/__init__.py Normal file
View file

@ -0,0 +1,19 @@
"""Re-export the public pipeline-stage API from the stage submodules."""

from .alignment_edits import AlignmentDecision, AlignmentHeuristicEngine, AlignmentResult
from .asr_whisper import AsrResult, AsrSegment, AsrWord, WhisperAsrStage
from .editor_llama import EditorResult, LlamaEditorStage
from .fact_guard import FactGuardEngine, FactGuardResult, FactGuardViolation

# Explicit public API: `from stages import *` exposes exactly these names.
__all__ = [
    "AlignmentDecision",
    "AlignmentHeuristicEngine",
    "AlignmentResult",
    "AsrResult",
    "AsrSegment",
    "AsrWord",
    "WhisperAsrStage",
    "EditorResult",
    "LlamaEditorStage",
    "FactGuardEngine",
    "FactGuardResult",
    "FactGuardViolation",
]

View file

@ -0,0 +1,298 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
from stages.asr_whisper import AsrWord
_MAX_CUE_GAP_S = 0.85
_MAX_CORRECTION_WORDS = 8
_MAX_LEFT_CONTEXT_WORDS = 8
_MAX_PHRASE_GAP_S = 0.9
@dataclass
class AlignmentDecision:
    """One heuristic edit decision recorded while rewriting a transcript."""

    # Which heuristic fired: "cue_correction" or "restart_repeat".
    rule_id: str
    # Word-index span [span_start, span_end) in the ASR word stream covered
    # by the decision.
    span_start: int
    span_end: int
    # Text substituted for the span; empty when the decision was skipped.
    replacement: str
    # "high" decisions are applied to the text; others are only recorded.
    confidence: str
    # Human-readable explanation of why the decision was made/skipped.
    reason: str
@dataclass
class AlignmentResult:
    """Outcome of ``AlignmentHeuristicEngine.apply``."""

    # Rewritten transcript; equals the stripped input when no high-confidence
    # decision was applied.
    draft_text: str
    # Full audit trail, applied and skipped decisions alike.
    decisions: list[AlignmentDecision]
    # Count of decisions with confidence == "high" (actually applied).
    applied_count: int
    # Count of the remaining, recorded-but-skipped decisions.
    skipped_count: int
class AlignmentHeuristicEngine:
    """Rule-based post-editor for ASR output using word-level timing.

    Two families of edits are applied:

    * cue corrections — a spoken cue ("i mean", "actually", "sorry", "no")
      followed by a correction phrase replaces the phrase just emitted;
    * restart collapsing — an exactly repeated adjacent phrase is reduced to
      a single copy.

    Only decisions recorded with confidence "high" modify the text; when none
    apply, the original transcript is returned unchanged.
    """

    def apply(self, transcript: str, words: list[AsrWord]) -> AlignmentResult:
        """Rewrite *transcript* guided by *words*.

        Returns an ``AlignmentResult`` carrying the (possibly unchanged) draft
        text plus an audit trail of every decision, applied or skipped.
        """
        base_text = (transcript or "").strip()
        if not base_text or not words:
            # Nothing to align; echo the (stripped) input with no decisions.
            return AlignmentResult(
                draft_text=base_text,
                decisions=[],
                applied_count=0,
                skipped_count=0,
            )
        normalized_words = [_normalize_token(word.text) for word in words]
        # Literal-dictation requests ("verbatim", "quote", ...) suppress the
        # "i mean" cue so quoted speech is not rewritten.
        literal_guard = _has_literal_guard(base_text)
        out_tokens: list[str] = []
        decisions: list[AlignmentDecision] = []
        i = 0
        while i < len(words):
            cue = _match_cue(words, normalized_words, i)
            # A cue is only actionable once some output exists to correct.
            if cue is not None and out_tokens:
                cue_len, cue_label = cue
                correction_start = i + cue_len
                correction_end = _capture_phrase_end(words, correction_start)
                if correction_end <= correction_start:
                    # Cue with nothing after it: log and skip past the cue words.
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=min(i + cue_len, len(words)),
                            replacement="",
                            confidence="low",
                            reason=f"{cue_label} has no correction phrase",
                        )
                    )
                    i += cue_len
                    continue
                correction_tokens = _slice_clean_words(words, correction_start, correction_end)
                if not correction_tokens:
                    # Correction phrase stripped to nothing (pure punctuation).
                    i = correction_end
                    continue
                left_start = _find_left_context_start(out_tokens)
                left_tokens = out_tokens[left_start:]
                candidate = _compose_replacement(left_tokens, correction_tokens)
                if candidate is None:
                    # No safe merge found; preserve the literal text. Only the
                    # cue's first word is consumed here, so the cue words are
                    # emitted as normal tokens on later iterations.
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=correction_end,
                            replacement="",
                            confidence="low",
                            reason=f"{cue_label} ambiguous; preserving literal",
                        )
                    )
                    i += 1
                    continue
                if literal_guard and cue_label == "i mean":
                    decisions.append(
                        AlignmentDecision(
                            rule_id="cue_correction",
                            span_start=i,
                            span_end=correction_end,
                            replacement="",
                            confidence="low",
                            reason="literal dictation context; preserving 'i mean'",
                        )
                    )
                    i += 1
                    continue
                # Apply the edit: the left context is replaced by the merged
                # candidate, and scanning resumes after the correction phrase.
                out_tokens = out_tokens[:left_start] + candidate
                decisions.append(
                    AlignmentDecision(
                        rule_id="cue_correction",
                        span_start=i,
                        span_end=correction_end,
                        replacement=" ".join(candidate),
                        confidence="high",
                        reason=f"{cue_label} replaces prior phrase with correction phrase",
                    )
                )
                i = correction_end
                continue
            token = _strip_token(words[i].text)
            if token:
                out_tokens.append(token)
            i += 1
        # Second pass over the assembled output: drop exact phrase restarts.
        out_tokens, repeat_decisions = _collapse_restarts(out_tokens)
        decisions.extend(repeat_decisions)
        applied_count = sum(1 for decision in decisions if decision.confidence == "high")
        skipped_count = len(decisions) - applied_count
        if applied_count <= 0:
            # Nothing applied: return the original text rather than re-joining
            # tokens, which would lose the source's punctuation/spacing.
            return AlignmentResult(
                draft_text=base_text,
                decisions=decisions,
                applied_count=0,
                skipped_count=skipped_count,
            )
        return AlignmentResult(
            draft_text=_render_words(out_tokens),
            decisions=decisions,
            applied_count=applied_count,
            skipped_count=skipped_count,
        )
def _match_cue(words: list[AsrWord], normalized_words: list[str], index: int) -> tuple[int, str] | None:
    """Return ``(cue_length, cue_label)`` when a correction cue starts at
    *index*, or ``None`` when no cue matches."""
    if index >= len(words):
        return None
    token = normalized_words[index]
    if token == "i":
        # "i mean" only counts as a cue when spoken fluently (short gap).
        followed_by_mean = index + 1 < len(words) and normalized_words[index + 1] == "mean"
        if followed_by_mean and _gap_since_previous(words, index) <= _MAX_CUE_GAP_S:
            return (2, "i mean")
        return None
    if token in ("actually", "sorry"):
        return (1, token)
    if token == "no" and words[index].text.strip().lower() in {"no", "no,"}:
        # Bare "no" (optionally with a comma) — but not e.g. "no." or "No!".
        return (1, "no")
    return None
def _gap_since_previous(words: list[AsrWord], index: int) -> float:
if index <= 0:
return 0.0
previous = words[index - 1]
current = words[index]
if previous.end_s <= 0.0 or current.start_s <= 0.0:
return 0.0
return max(current.start_s - previous.end_s, 0.0)
def _capture_phrase_end(words: list[AsrWord], start: int) -> int:
    """Scan forward from *start* and return the exclusive end index of the
    correction phrase, bounded by ``_MAX_CORRECTION_WORDS``, a sentence-ending
    token, or a pause longer than ``_MAX_PHRASE_GAP_S``."""
    end = start
    total = len(words)
    while end < total:
        if end - start >= _MAX_CORRECTION_WORDS:
            break
        current = words[end].text.strip()
        end += 1
        if not current:
            # Skip empty tokens without boundary checks.
            continue
        if _ends_sentence(current):
            break
        if end < total:
            pause = max(words[end].start_s - words[end - 1].end_s, 0.0)
            if pause > _MAX_PHRASE_GAP_S:
                break
    return end
def _find_left_context_start(tokens: list[str]) -> int:
    """Walk backwards from the end of *tokens* and return the index where the
    replaceable left context begins — stopping at a sentence boundary or after
    ``_MAX_LEFT_CONTEXT_WORDS`` words."""
    start = len(tokens)
    for _ in range(_MAX_LEFT_CONTEXT_WORDS):
        if start <= 0 or _ends_sentence(tokens[start - 1]):
            break
        start -= 1
    return start
def _compose_replacement(left_tokens: list[str], correction_tokens: list[str]) -> list[str] | None:
    """Merge the correction phrase into the left context.

    Returns the merged token list, or ``None`` when the merge would be
    ambiguous and the literal text should be preserved.
    """
    if not left_tokens or not correction_tokens:
        return None
    normalized_left = [_normalize_token(token) for token in left_tokens]
    normalized_right = [_normalize_token(token) for token in correction_tokens]
    if not any(normalized_right):
        return None
    shared = 0
    for lhs, rhs in zip(normalized_left, normalized_right):
        if lhs != rhs:
            break
        shared += 1
    if shared:
        # Re-statement with a shared prefix: keep the prefix, swap the tail.
        return left_tokens[:shared] + correction_tokens[shared:]
    if len(correction_tokens) <= 2 and len(left_tokens) >= 1:
        # Short correction: replace the same number of trailing words.
        keep = max(len(left_tokens) - len(correction_tokens), 0)
        return left_tokens[:keep] + correction_tokens
    if len(correction_tokens) == 1 and len(left_tokens) >= 2:
        return left_tokens[:-1] + correction_tokens
    return None
def _collapse_restarts(tokens: list[str]) -> tuple[list[str], list[AlignmentDecision]]:
    """Collapse immediately repeated phrases ("restarts") in *tokens*.

    Looks for two adjacent, case-insensitively identical chunks of 3-6 words,
    deletes the first copy (keeping the second copy's surface form), and
    restarts the scan after every deletion so newly adjacent repeats are also
    caught. Returns the collapsed token list plus one "restart_repeat"
    decision per deletion.
    """
    # Fewer than 6 tokens cannot contain two adjacent 3-word chunks.
    if len(tokens) < 6:
        return tokens, []
    mutable = list(tokens)
    decisions: list[AlignmentDecision] = []
    changed = True
    while changed:
        changed = False
        max_chunk = min(6, len(mutable) // 2)
        # Prefer the longest repeat; range(..., 2, -1) keeps chunk_size >= 3.
        for chunk_size in range(max_chunk, 2, -1):
            index = 0
            while index + (2 * chunk_size) <= len(mutable):
                left = [_normalize_token(item) for item in mutable[index : index + chunk_size]]
                right = [_normalize_token(item) for item in mutable[index + chunk_size : index + (2 * chunk_size)]]
                if left == right and left:
                    replacement = " ".join(mutable[index + chunk_size : index + (2 * chunk_size)])
                    # Delete the first copy, keeping the second one in place.
                    del mutable[index : index + chunk_size]
                    # NOTE(review): span indices refer to positions in the list
                    # as it stood at deletion time, not the original input.
                    decisions.append(
                        AlignmentDecision(
                            rule_id="restart_repeat",
                            span_start=index,
                            span_end=index + (2 * chunk_size),
                            replacement=replacement,
                            confidence="high",
                            reason="collapsed exact repeated phrase restart",
                        )
                    )
                    changed = True
                    break
                index += 1
            if changed:
                # One edit per outer pass: restart the whole scan.
                break
    return mutable, decisions
def _slice_clean_words(words: list[AsrWord], start: int, end: int) -> list[str]:
    """Punctuation-stripped tokens from ``words[start:end]``, empties dropped."""
    return [stripped for word in words[start:end] if (stripped := _strip_token(word.text))]
def _strip_token(text: str) -> str:
token = (text or "").strip()
if not token:
return ""
# Remove surrounding punctuation but keep internal apostrophes/hyphens.
return token.strip(" \t\n\r\"'`“”‘’.,!?;:()[]{}")
def _normalize_token(text: str) -> str:
    """Casefolded form of the punctuation-stripped token, for comparisons."""
    stripped = _strip_token(text)
    return stripped.casefold()
def _render_words(tokens: Iterable[str]) -> str:
cleaned = [token.strip() for token in tokens if token and token.strip()]
return " ".join(cleaned).strip()
def _ends_sentence(token: str) -> bool:
trimmed = (token or "").strip()
return trimmed.endswith(".") or trimmed.endswith("!") or trimmed.endswith("?")
def _has_literal_guard(text: str) -> bool:
normalized = " ".join((text or "").casefold().split())
guards = (
"write exactly",
"keep this literal",
"keep literal",
"verbatim",
"quote",
)
return any(guard in normalized for guard in guards)

134
src/stages/asr_whisper.py Normal file
View file

@ -0,0 +1,134 @@
from __future__ import annotations
import logging
import inspect
import time
from dataclasses import dataclass
from typing import Any, Callable
@dataclass
class AsrWord:
    """One word-level timestamp entry emitted by the ASR backend."""

    text: str
    # Start/end time in seconds (0.0 when the backend omits the value).
    start_s: float
    end_s: float
    # Word probability reported by the model, when available.
    prob: float | None
@dataclass
class AsrSegment:
    """One segment-level transcription entry with its time span."""

    text: str
    # Segment boundaries in seconds (0.0 when the backend omits the value).
    start_s: float
    end_s: float
@dataclass
class AsrResult:
    """Aggregated transcription output of ``WhisperAsrStage.transcribe``."""

    # All segment texts joined with single spaces.
    raw_text: str
    # Language actually used: the configured hint, or "auto" when detection
    # was used (including after a rejected-hint fallback).
    language: str
    # Wall-clock duration of the transcribe call, in milliseconds.
    latency_ms: float
    # Word-level entries; empty when the backend lacks word timestamps.
    words: list[AsrWord]
    segments: list[AsrSegment]
def _is_stt_language_hint_error(exc: Exception) -> bool:
text = str(exc).casefold()
has_language = "language" in text
unsupported = "unsupported" in text or "not supported" in text or "unknown" in text
return has_language and unsupported
class WhisperAsrStage:
    """ASR stage wrapping a Whisper-style model object.

    The model is duck-typed: it must expose ``transcribe(audio, **kwargs)``
    returning ``(segments, info)`` where each segment carries ``text``,
    ``start``, ``end`` and optionally per-word ``words`` entries — this matches
    the faster-whisper shape, but confirm against the concrete backend.
    """

    def __init__(
        self,
        model: Any,
        *,
        configured_language: str,
        hint_kwargs_provider: Callable[[], dict[str, Any]] | None = None,
    ) -> None:
        self._model = model
        # Normalize to lowercase; empty/None collapses to "auto" (detection).
        self._configured_language = (configured_language or "auto").strip().lower() or "auto"
        # Extra transcribe() kwargs supplied per call (e.g. prompt hints).
        self._hint_kwargs_provider = hint_kwargs_provider or (lambda: {})
        # Only request word timestamps when the backend's signature takes them.
        self._supports_word_timestamps = _supports_parameter(model.transcribe, "word_timestamps")

    def transcribe(self, audio: Any) -> AsrResult:
        """Transcribe *audio*, falling back to language auto-detection when the
        configured language hint is rejected by the backend.

        Returns an ``AsrResult`` with joined text, segment entries, word-level
        entries (when supported), and the call's wall-clock latency.
        """
        kwargs: dict[str, Any] = {"vad_filter": True}
        if self._configured_language != "auto":
            kwargs["language"] = self._configured_language
        if self._supports_word_timestamps:
            kwargs["word_timestamps"] = True
        # Provider kwargs are applied last so they can override the defaults.
        kwargs.update(self._hint_kwargs_provider())
        effective_language = self._configured_language
        started = time.perf_counter()
        try:
            segments, _info = self._model.transcribe(audio, **kwargs)
        except Exception as exc:
            if self._configured_language != "auto" and _is_stt_language_hint_error(exc):
                # The hint was the problem: retry once without it.
                logging.warning(
                    "stt language hint '%s' was rejected; falling back to auto-detect",
                    self._configured_language,
                )
                fallback_kwargs = dict(kwargs)
                fallback_kwargs.pop("language", None)
                segments, _info = self._model.transcribe(audio, **fallback_kwargs)
                effective_language = "auto"
            else:
                raise
        parts: list[str] = []
        words: list[AsrWord] = []
        asr_segments: list[AsrSegment] = []
        for seg in segments:
            text = (getattr(seg, "text", "") or "").strip()
            if text:
                parts.append(text)
                # getattr defaults tolerate backends that omit timestamps.
                start_s = float(getattr(seg, "start", 0.0) or 0.0)
                end_s = float(getattr(seg, "end", 0.0) or 0.0)
                asr_segments.append(
                    AsrSegment(
                        text=text,
                        start_s=start_s,
                        end_s=end_s,
                    )
                )
            segment_words = getattr(seg, "words", None)
            if not segment_words:
                continue
            for word in segment_words:
                token = (getattr(word, "word", "") or "").strip()
                if not token:
                    continue
                words.append(
                    AsrWord(
                        text=token,
                        start_s=float(getattr(word, "start", 0.0) or 0.0),
                        end_s=float(getattr(word, "end", 0.0) or 0.0),
                        prob=_optional_float(getattr(word, "probability", None)),
                    )
                )
        latency_ms = (time.perf_counter() - started) * 1000.0
        return AsrResult(
            raw_text=" ".join(parts).strip(),
            language=effective_language,
            latency_ms=latency_ms,
            words=words,
            segments=asr_segments,
        )
def _supports_parameter(callable_obj: Callable[..., Any], parameter: str) -> bool:
try:
signature = inspect.signature(callable_obj)
except (TypeError, ValueError):
return False
return parameter in signature.parameters
def _optional_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None

View file

@ -0,0 +1,64 @@
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any
@dataclass
class EditorResult:
    """Outcome of a rewrite pass with its timing breakdown (milliseconds)."""

    final_text: str
    latency_ms: float
    pass1_ms: float
    pass2_ms: float


class LlamaEditorStage:
    """Thin adapter around a duck-typed text-rewriting processor.

    Prefers the processor's ``process_with_metrics`` API when present and
    falls back to plain ``process`` (reporting the whole call as pass 2).
    """

    def __init__(self, processor: Any, *, profile: str = "default") -> None:
        self._processor = processor
        self._profile = self._canonical_profile(profile)

    @staticmethod
    def _canonical_profile(profile: str) -> str:
        # Normalize to a lowercase, non-empty name; blanks become "default".
        name = (profile or "default").strip().lower()
        return name or "default"

    def set_profile(self, profile: str) -> None:
        """Switch the active rewrite profile."""
        self._profile = self._canonical_profile(profile)

    def warmup(self) -> None:
        """Pre-load the processor for the active profile."""
        self._processor.warmup(profile=self._profile)

    def rewrite(
        self,
        transcript: str,
        *,
        language: str,
        dictionary_context: str,
    ) -> EditorResult:
        """Run the processor on *transcript* and collect timing metrics."""
        started = time.perf_counter()
        if hasattr(self._processor, "process_with_metrics"):
            rewritten, timings = self._processor.process_with_metrics(
                transcript,
                lang=language,
                dictionary_context=dictionary_context,
                profile=self._profile,
            )
            total_ms = float(getattr(timings, "total_ms", 0.0))
            if total_ms <= 0.0:
                # Processor did not report a total; use wall-clock instead.
                total_ms = (time.perf_counter() - started) * 1000.0
            return EditorResult(
                final_text=(rewritten or "").strip(),
                latency_ms=total_ms,
                pass1_ms=float(getattr(timings, "pass1_ms", 0.0)),
                pass2_ms=float(getattr(timings, "pass2_ms", 0.0)),
            )
        rewritten = self._processor.process(
            transcript,
            lang=language,
            dictionary_context=dictionary_context,
            profile=self._profile,
        )
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        return EditorResult(
            final_text=(rewritten or "").strip(),
            latency_ms=elapsed_ms,
            pass1_ms=0.0,
            pass2_ms=elapsed_ms,
        )

294
src/stages/fact_guard.py Normal file
View file

@ -0,0 +1,294 @@
from __future__ import annotations
import re
import time
from dataclasses import dataclass
from difflib import SequenceMatcher
@dataclass
class FactGuardViolation:
    """One detected factual discrepancy between source and candidate text."""

    # Which check fired: "entity_preservation", "entity_invention", or
    # "diff_additions".
    rule_id: str
    # "high" for pattern entities / strict additions, "medium" for
    # name-or-term entities.
    severity: str
    # Offending text on each side; empty when that side has no counterpart.
    source_span: str
    candidate_span: str
    # Human-readable explanation.
    reason: str
@dataclass
class FactGuardResult:
    """Outcome of ``FactGuardEngine.apply``."""

    # Text to use downstream: the candidate when accepted, else the source.
    final_text: str
    # "accepted", "rejected" (strict mode), or "fallback".
    action: str
    violations: list[FactGuardViolation]
    # Convenience count, always len(violations).
    violations_count: int
    # Wall-clock duration of the check, in milliseconds.
    latency_ms: float
@dataclass(frozen=True)
class _FactEntity:
    """An extracted checkable entity, stored under its normalized key."""

    # Normalized (casefolded, whitespace-collapsed) lookup key.
    key: str
    # Original surface form after punctuation trimming.
    value: str
    # "url", "email", "identifier", "number", or "name_or_term".
    kind: str
    # Severity label attached to violations built from this entity.
    severity: str
# Pattern-based entity extractors (all treated as high severity).
_URL_RE = re.compile(r"\bhttps?://[^\s<>\"']+")
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# Ticket-style IDs ("ABC-123") or alphanumeric codes ("x86_64", "v2beta1").
_ID_RE = re.compile(r"\b(?:[A-Z]{2,}[-_]\d+[A-Z0-9]*|[A-Za-z]+\d+[A-Za-z0-9_-]*)\b")
# Numbers incl. decimals/times/percentages ("3.14", "10:30", "50%", "9am").
_NUMBER_RE = re.compile(r"\b\d+(?:[.,:]\d+)*(?:%|am|pm)?\b", re.IGNORECASE)
# Any non-whitespace run; candidates for the name-or-term check.
_TOKEN_RE = re.compile(r"\b[^\s]+\b")
# Word tokens used by the strict SequenceMatcher diff.
_DIFF_TOKEN_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9'_-]*")
# Function words ignored when judging strict-diff additions.
_SOFT_TOKENS = {
    "a",
    "an",
    "and",
    "are",
    "as",
    "at",
    "be",
    "been",
    "being",
    "but",
    "by",
    "for",
    "from",
    "if",
    "in",
    "is",
    "it",
    "its",
    "of",
    "on",
    "or",
    "that",
    "the",
    "their",
    "then",
    "there",
    "these",
    "they",
    "this",
    "those",
    "to",
    "was",
    "we",
    "were",
    "with",
    "you",
    "your",
}
# Common command/pleasantry words never treated as name-or-term entities.
_NON_FACT_WORDS = {
    "please",
    "hello",
    "hi",
    "thanks",
    "thank",
    "set",
    "send",
    "write",
    "schedule",
    "call",
    "meeting",
    "email",
    "message",
    "note",
}
class FactGuardEngine:
    """Compares a rewritten candidate against its source transcript and rejects
    rewrites that drop or invent factual entities (URLs, emails, identifiers,
    numbers, capitalized names/terms)."""

    def apply(
        self,
        source_text: str,
        candidate_text: str,
        *,
        enabled: bool,
        strict: bool,
    ) -> FactGuardResult:
        """Check *candidate_text* against *source_text*.

        When disabled, the candidate is accepted as-is. Otherwise any
        violation forces a fall back to the source text; in strict mode the
        action is labelled "rejected" and large lexical additions also count.
        """
        started = time.perf_counter()
        source = (source_text or "").strip()
        # An empty candidate silently falls back to the source.
        candidate = (candidate_text or "").strip() or source
        if not enabled:
            return self._result(candidate, "accepted", [], started)
        source_entities = _extract_entities(source)
        candidate_entities = _extract_entities(candidate)
        # Entities present in the source but missing from the candidate.
        violations = [
            FactGuardViolation(
                rule_id="entity_preservation",
                severity=entity.severity,
                source_span=entity.value,
                candidate_span="",
                reason=f"candidate dropped source {entity.kind}",
            )
            for key, entity in source_entities.items()
            if key not in candidate_entities
        ]
        # Entities the candidate introduced out of thin air.
        violations.extend(
            FactGuardViolation(
                rule_id="entity_invention",
                severity=entity.severity,
                source_span="",
                candidate_span=entity.value,
                reason=f"candidate introduced new {entity.kind}",
            )
            for key, entity in candidate_entities.items()
            if key not in source_entities
        )
        if strict:
            additions = _strict_additions(source, candidate)
            if additions:
                violations.append(
                    FactGuardViolation(
                        rule_id="diff_additions",
                        severity="high",
                        source_span="",
                        candidate_span=" ".join(additions[:8]),
                        reason="strict mode blocks substantial lexical additions",
                    )
                )
        violations = _dedupe_violations(violations)
        if not violations:
            return self._result(candidate, "accepted", violations, started)
        return self._result(source, "rejected" if strict else "fallback", violations, started)

    @staticmethod
    def _result(
        text: str,
        action: str,
        violations: list[FactGuardViolation],
        started: float,
    ) -> FactGuardResult:
        # Assemble the result, measuring latency from *started*.
        return FactGuardResult(
            final_text=text,
            action=action,
            violations=violations,
            violations_count=len(violations),
            latency_ms=(time.perf_counter() - started) * 1000.0,
        )
def _extract_entities(text: str) -> dict[str, _FactEntity]:
    """Collect checkable entities from *text*, keyed by normalized form.

    Pattern-based entities (URLs, emails, IDs, numbers) are recorded first at
    high severity; capitalized names/terms follow at medium severity. The
    first occurrence of a key wins.
    """
    found: dict[str, _FactEntity] = {}
    if not text:
        return found
    pattern_kinds = (
        (_URL_RE, "url"),
        (_EMAIL_RE, "email"),
        (_ID_RE, "identifier"),
        (_NUMBER_RE, "number"),
    )
    for pattern, kind in pattern_kinds:
        for match in pattern.finditer(text):
            _add_entity(found, match.group(0), kind=kind, severity="high")
    for match in _TOKEN_RE.finditer(text):
        token = _clean_token(match.group(0))
        if token and token.casefold() not in _NON_FACT_WORDS and _looks_name_or_term(token):
            _add_entity(found, token, kind="name_or_term", severity="medium")
    return found
def _add_entity(
    entities: dict[str, _FactEntity],
    token: str,
    *,
    kind: str,
    severity: str,
) -> None:
    """Insert *token* into *entities* under its normalized key.

    Empty tokens/keys are ignored and the first occurrence of a key wins.
    """
    cleaned = _clean_token(token)
    key = _normalize_key(cleaned) if cleaned else ""
    if not key or key in entities:
        return
    entities[key] = _FactEntity(key=key, value=cleaned, kind=kind, severity=severity)
def _strict_additions(source_text: str, candidate_text: str) -> list[str]:
    """Tokens the candidate adds relative to the source.

    Returns the meaningful added tokens only when they are numerous enough to
    matter (at least 2 tokens and at least 10% of the source length);
    otherwise an empty list.
    """
    base = _diff_tokens(source_text)
    other = _diff_tokens(candidate_text)
    if not base or not other:
        return []
    inserted: list[str] = []
    for op, _a1, _a2, b1, b2 in SequenceMatcher(a=base, b=other).get_opcodes():
        if op in {"insert", "replace"}:
            inserted.extend(other[b1:b2])
    meaningful = [token for token in inserted if _is_meaningful_added_token(token)]
    if len(meaningful) >= 2 and len(meaningful) / max(len(base), 1) >= 0.1:
        return meaningful
    return []
def _diff_tokens(text: str) -> list[str]:
    """Casefolded word tokens used for the strict diff comparison."""
    matches = _DIFF_TOKEN_RE.finditer(text or "")
    return [match.group(0).casefold() for match in matches]
def _looks_name_or_term(token: str) -> bool:
if len(token) < 2:
return False
if any(ch.isdigit() for ch in token):
return False
has_upper = any(ch.isupper() for ch in token)
if not has_upper:
return False
if token.isupper() and len(token) >= 2:
return True
if token[0].isupper():
return True
# Mixed-case term like "iPhone".
return any(ch.isupper() for ch in token[1:])
def _is_meaningful_added_token(token: str) -> bool:
    """Filter out single characters and stopwords from strict-diff additions."""
    return len(token) > 1 and token not in _SOFT_TOKENS
def _clean_token(token: str) -> str:
return (token or "").strip(" \t\n\r\"'`.,!?;:()[]{}")
def _normalize_key(value: str) -> str:
return " ".join((value or "").casefold().split())
def _dedupe_violations(violations: list[FactGuardViolation]) -> list[FactGuardViolation]:
deduped: list[FactGuardViolation] = []
seen: set[tuple[str, str, str, str]] = set()
for item in violations:
key = (item.rule_id, item.severity, item.source_span.casefold(), item.candidate_span.casefold())
if key in seen:
continue
seen.add(key)
deduped.append(item)
return deduped