Simplify editor cleanup and keep live ASR metadata
Some checks are pending
ci / test-and-build (push) Waiting to run

Keep the daemon path on the full ASR result so word timings and detected language survive into the editor pipeline instead of falling back to a plain transcript string.

Add PipelineEngine.run_asr_result(), have aman call it when live ASR data is available, and cover the word-aware alignment behavior in the daemon tests.

Collapse the llama cleanup flow to a single JSON-shaped completion while leaving the legacy pass1/pass2 parameters in place as compatibility no-ops.

Validated with `PYTHONPATH=src python3 -m unittest tests.test_aiprocess tests.test_aman`.
This commit is contained in:
Thales Maciel 2026-03-12 13:24:36 -03:00
parent 8c1f7c1e13
commit fa91f313c4
No known key found for this signature in database
GPG key ID: 33112E6833C34679
5 changed files with 166 additions and 84 deletions

View file

@ -207,7 +207,29 @@ PASS2_SYSTEM_PROMPT = (
# Keep a stable symbol for documentation and tooling. # Keep a stable symbol for documentation and tooling.
SYSTEM_PROMPT = PASS2_SYSTEM_PROMPT SYSTEM_PROMPT = (
"You are an amanuensis working for an user.\n"
"You'll receive a JSON object with the transcript and optional context.\n"
"Your job is to rewrite the user's transcript into clean prose.\n"
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
"Rules:\n"
"- Preserve meaning, facts, and intent.\n"
"- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n"
"- Preserve wording. Do not replace words for synonyms\n"
"- Do not add new info.\n"
"- Remove filler words (um/uh/like)\n"
"- Remove false starts\n"
"- Remove self-corrections.\n"
"- If a dictionary section exists, apply only the listed corrections.\n"
"- Keep dictionary spellings exactly as provided.\n"
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
"- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
"- Do not wrap with markdown, tags, or extra keys.\n\n"
"Examples:\n"
" - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
" - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
" - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
)
class LlamaProcessor: class LlamaProcessor:
@ -275,15 +297,8 @@ class LlamaProcessor:
min(max_tokens, WARMUP_MAX_TOKENS) if isinstance(max_tokens, int) else WARMUP_MAX_TOKENS min(max_tokens, WARMUP_MAX_TOKENS) if isinstance(max_tokens, int) else WARMUP_MAX_TOKENS
) )
response = self._invoke_completion( response = self._invoke_completion(
system_prompt=PASS2_SYSTEM_PROMPT, system_prompt=SYSTEM_PROMPT,
user_prompt=_build_pass2_user_prompt_xml( user_prompt=_build_user_prompt_xml(request_payload),
request_payload,
pass1_payload={
"candidate_text": request_payload["transcript"],
"decision_spans": [],
},
pass1_error="",
),
profile=profile, profile=profile,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
@ -373,77 +388,43 @@ class LlamaProcessor:
pass2_repeat_penalty: float | None = None, pass2_repeat_penalty: float | None = None,
pass2_min_p: float | None = None, pass2_min_p: float | None = None,
) -> tuple[str, ProcessTimings]: ) -> tuple[str, ProcessTimings]:
_ = (
pass1_temperature,
pass1_top_p,
pass1_top_k,
pass1_max_tokens,
pass1_repeat_penalty,
pass1_min_p,
pass2_temperature,
pass2_top_p,
pass2_top_k,
pass2_max_tokens,
pass2_repeat_penalty,
pass2_min_p,
)
request_payload = _build_request_payload( request_payload = _build_request_payload(
text, text,
lang=lang, lang=lang,
dictionary_context=dictionary_context, dictionary_context=dictionary_context,
) )
p1_temperature = pass1_temperature if pass1_temperature is not None else temperature
p1_top_p = pass1_top_p if pass1_top_p is not None else top_p
p1_top_k = pass1_top_k if pass1_top_k is not None else top_k
p1_max_tokens = pass1_max_tokens if pass1_max_tokens is not None else max_tokens
p1_repeat_penalty = pass1_repeat_penalty if pass1_repeat_penalty is not None else repeat_penalty
p1_min_p = pass1_min_p if pass1_min_p is not None else min_p
p2_temperature = pass2_temperature if pass2_temperature is not None else temperature
p2_top_p = pass2_top_p if pass2_top_p is not None else top_p
p2_top_k = pass2_top_k if pass2_top_k is not None else top_k
p2_max_tokens = pass2_max_tokens if pass2_max_tokens is not None else max_tokens
p2_repeat_penalty = pass2_repeat_penalty if pass2_repeat_penalty is not None else repeat_penalty
p2_min_p = pass2_min_p if pass2_min_p is not None else min_p
started_total = time.perf_counter() started_total = time.perf_counter()
response = self._invoke_completion(
started_pass1 = time.perf_counter() system_prompt=SYSTEM_PROMPT,
pass1_response = self._invoke_completion( user_prompt=_build_user_prompt_xml(request_payload),
system_prompt=PASS1_SYSTEM_PROMPT,
user_prompt=_build_pass1_user_prompt_xml(request_payload),
profile=profile, profile=profile,
temperature=p1_temperature, temperature=temperature,
top_p=p1_top_p, top_p=top_p,
top_k=p1_top_k, top_k=top_k,
max_tokens=p1_max_tokens, max_tokens=max_tokens,
repeat_penalty=p1_repeat_penalty, repeat_penalty=repeat_penalty,
min_p=p1_min_p, min_p=min_p,
adaptive_max_tokens=_recommended_analysis_max_tokens(request_payload["transcript"]),
)
pass1_ms = (time.perf_counter() - started_pass1) * 1000.0
pass1_error = ""
try:
pass1_payload = _extract_pass1_analysis(pass1_response)
except Exception as exc:
pass1_payload = {
"candidate_text": request_payload["transcript"],
"decision_spans": [],
}
pass1_error = str(exc)
started_pass2 = time.perf_counter()
pass2_response = self._invoke_completion(
system_prompt=PASS2_SYSTEM_PROMPT,
user_prompt=_build_pass2_user_prompt_xml(
request_payload,
pass1_payload=pass1_payload,
pass1_error=pass1_error,
),
profile=profile,
temperature=p2_temperature,
top_p=p2_top_p,
top_k=p2_top_k,
max_tokens=p2_max_tokens,
repeat_penalty=p2_repeat_penalty,
min_p=p2_min_p,
adaptive_max_tokens=_recommended_final_max_tokens(request_payload["transcript"], profile), adaptive_max_tokens=_recommended_final_max_tokens(request_payload["transcript"], profile),
) )
pass2_ms = (time.perf_counter() - started_pass2) * 1000.0 cleaned_text = _extract_cleaned_text(response)
cleaned_text = _extract_cleaned_text(pass2_response)
total_ms = (time.perf_counter() - started_total) * 1000.0 total_ms = (time.perf_counter() - started_total) * 1000.0
return cleaned_text, ProcessTimings( return cleaned_text, ProcessTimings(
pass1_ms=pass1_ms, pass1_ms=0.0,
pass2_ms=pass2_ms, pass2_ms=total_ms,
total_ms=total_ms, total_ms=total_ms,
) )
@ -568,17 +549,7 @@ class ExternalApiProcessor:
"model": self.model, "model": self.model,
"messages": [ "messages": [
{"role": "system", "content": SYSTEM_PROMPT}, {"role": "system", "content": SYSTEM_PROMPT},
{ {"role": "user", "content": _build_user_prompt_xml(request_payload)},
"role": "user",
"content": _build_pass2_user_prompt_xml(
request_payload,
pass1_payload={
"candidate_text": request_payload["transcript"],
"decision_spans": [],
},
pass1_error="",
),
},
], ],
"temperature": temperature if temperature is not None else 0.0, "temperature": temperature if temperature is not None else 0.0,
"response_format": {"type": "json_object"}, "response_format": {"type": "json_object"},
@ -879,7 +850,19 @@ def _build_pass2_user_prompt_xml(
# Backward-compatible helper name. # Backward-compatible helper name.
def _build_user_prompt_xml(payload: dict[str, Any]) -> str: def _build_user_prompt_xml(payload: dict[str, Any]) -> str:
return _build_pass1_user_prompt_xml(payload) language = escape(str(payload.get("language", "auto")))
transcript = escape(str(payload.get("transcript", "")))
dictionary = escape(str(payload.get("dictionary", ""))).strip()
lines = [
"<request>",
f" <language>{language}</language>",
f" <transcript>{transcript}</transcript>",
]
if dictionary:
lines.append(f" <dictionary>{dictionary}</dictionary>")
lines.append(' <output_contract>{"cleaned_text":"..."}</output_contract>')
lines.append("</request>")
return "\n".join(lines)
def _extract_pass1_analysis(payload: Any) -> dict[str, Any]: def _extract_pass1_analysis(payload: Any) -> dict[str, Any]:

View file

@ -142,6 +142,7 @@ def _process_transcript_pipeline(
stt_lang: str, stt_lang: str,
pipeline: PipelineEngine, pipeline: PipelineEngine,
suppress_ai_errors: bool, suppress_ai_errors: bool,
asr_result: AsrResult | None = None,
asr_ms: float = 0.0, asr_ms: float = 0.0,
verbose: bool = False, verbose: bool = False,
) -> tuple[str, TranscriptProcessTimings]: ) -> tuple[str, TranscriptProcessTimings]:
@ -161,7 +162,10 @@ def _process_transcript_pipeline(
total_ms=asr_ms, total_ms=asr_ms,
) )
try: try:
result = pipeline.run_transcript(processed, language=stt_lang) if asr_result is not None:
result = pipeline.run_asr_result(asr_result)
else:
result = pipeline.run_transcript(processed, language=stt_lang)
except Exception as exc: except Exception as exc:
if suppress_ai_errors: if suppress_ai_errors:
logging.error("editor stage failed: %s", exc) logging.error("editor stage failed: %s", exc)
@ -546,6 +550,7 @@ class Daemon:
stt_lang=stt_lang, stt_lang=stt_lang,
pipeline=self.pipeline, pipeline=self.pipeline,
suppress_ai_errors=False, suppress_ai_errors=False,
asr_result=asr_result,
asr_ms=asr_result.latency_ms, asr_ms=asr_result.latency_ms,
verbose=self.log_transcript, verbose=self.log_transcript,
) )

View file

@ -53,12 +53,20 @@ class PipelineEngine:
raise RuntimeError("asr stage is not configured") raise RuntimeError("asr stage is not configured")
started = time.perf_counter() started = time.perf_counter()
asr_result = self._asr_stage.transcribe(audio) asr_result = self._asr_stage.transcribe(audio)
return self.run_asr_result(asr_result, started_at=started)
def run_asr_result(
self,
asr_result: AsrResult,
*,
started_at: float | None = None,
) -> PipelineResult:
return self._run_transcript_core( return self._run_transcript_core(
asr_result.raw_text, asr_result.raw_text,
language=asr_result.language, language=asr_result.language,
asr_result=asr_result, asr_result=asr_result,
words=asr_result.words, words=asr_result.words,
started_at=started, started_at=time.perf_counter() if started_at is None else started_at,
) )
def run_transcript(self, transcript: str, *, language: str = "auto") -> PipelineResult: def run_transcript(self, transcript: str, *, language: str = "auto") -> PipelineResult:

View file

@ -186,6 +186,29 @@ class LlamaWarmupTests(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "expected JSON"): with self.assertRaisesRegex(RuntimeError, "expected JSON"):
processor.warmup(profile="default") processor.warmup(profile="default")
def test_process_with_metrics_uses_single_completion_timing_shape(self):
    """A successful run issues exactly one completion and reports all time as pass2."""
    fake_client = _WarmupClient(
        {"choices": [{"message": {"content": '{"cleaned_text":"friday"}'}}]}
    )
    # Bypass __init__ so no model is loaded; only the client attribute is needed.
    processor = object.__new__(LlamaProcessor)
    processor.client = fake_client

    cleaned_text, timings = processor.process_with_metrics(
        "thursday, I mean friday",
        lang="en",
        dictionary_context="",
        profile="default",
    )

    self.assertEqual(cleaned_text, "friday")
    self.assertEqual(len(fake_client.calls), 1)
    only_call = fake_client.calls[0]
    self.assertEqual(only_call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
    self.assertIn('{"cleaned_text":"..."}', only_call["messages"][1]["content"])
    # Single-completion flow: pass1 is a no-op, pass2 carries the whole duration.
    self.assertEqual(timings.pass1_ms, 0.0)
    self.assertGreater(timings.pass2_ms, 0.0)
    self.assertEqual(timings.pass2_ms, timings.total_ms)
class ModelChecksumTests(unittest.TestCase): class ModelChecksumTests(unittest.TestCase):
def test_accepts_expected_checksum_case_insensitive(self): def test_accepts_expected_checksum_case_insensitive(self):

View file

@ -12,6 +12,7 @@ if str(SRC) not in sys.path:
import aman import aman
from config import Config, VocabularyReplacement from config import Config, VocabularyReplacement
from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
class FakeDesktop: class FakeDesktop:
@ -144,6 +145,21 @@ class FakeStream:
self.close_calls += 1 self.close_calls += 1
def _asr_result(text: str, words: list[str], *, language: str = "auto") -> AsrResult:
    """Build a synthetic AsrResult for tests.

    Words are spaced 0.2s apart and each spans 0.1s; a single segment
    covers the whole utterance (never shorter than 0.1s).
    """
    word_objs = [
        AsrWord(text=word, start_s=index * 0.2, end_s=index * 0.2 + 0.1, prob=0.9)
        for index, word in enumerate(words)
    ]
    total_span = len(words) * 0.2
    return AsrResult(
        raw_text=text,
        language=language,
        latency_ms=5.0,
        words=word_objs,
        segments=[AsrSegment(text=text, start_s=0.0, end_s=max(total_span, 0.1))],
    )
class DaemonTests(unittest.TestCase): class DaemonTests(unittest.TestCase):
def _config(self) -> Config: def _config(self) -> Config:
cfg = Config() cfg = Config()
@ -248,6 +264,53 @@ class DaemonTests(unittest.TestCase):
self.assertEqual(desktop.inject_calls, []) self.assertEqual(desktop.inject_calls, [])
self.assertEqual(daemon.get_state(), aman.State.IDLE) self.assertEqual(daemon.get_state(), aman.State.IDLE)
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
@patch("aman.start_audio_recording", return_value=(object(), object()))
def test_live_path_uses_asr_words_for_alignment_correction(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
ai_processor = FakeAIProcessor()
daemon = self._build_daemon(desktop, FakeModel(), verbose=False, ai_processor=ai_processor)
daemon.asr_stage.transcribe = lambda _audio: _asr_result(
"set alarm for 6 i mean 7",
["set", "alarm", "for", "6", "i", "mean", "7"],
language="en",
)
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
)
)
daemon.toggle()
daemon.toggle()
self.assertEqual(desktop.inject_calls, [("set alarm for 7", "clipboard", False)])
self.assertEqual(ai_processor.last_kwargs.get("lang"), "en")
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
@patch("aman.start_audio_recording", return_value=(object(), object()))
def test_live_path_calls_word_aware_pipeline_entrypoint(self, _start_mock, _stop_mock):
desktop = FakeDesktop()
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
asr_result = _asr_result(
"set alarm for 6 i mean 7",
["set", "alarm", "for", "6", "i", "mean", "7"],
language="en",
)
daemon.asr_stage.transcribe = lambda _audio: asr_result
daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, record, trigger, process_audio
)
)
with patch.object(daemon.pipeline, "run_asr_result", wraps=daemon.pipeline.run_asr_result) as run_asr:
daemon.toggle()
daemon.toggle()
run_asr.assert_called_once()
self.assertIs(run_asr.call_args.args[0], asr_result)
def test_transcribe_skips_hints_when_model_does_not_support_them(self): def test_transcribe_skips_hints_when_model_does_not_support_them(self):
desktop = FakeDesktop() desktop = FakeDesktop()
model = FakeModel(text="hello") model = FakeModel(text="hello")