From fa91f313c46d0db90b52b772e1b664ad837dd072 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 12 Mar 2026 13:24:36 -0300 Subject: [PATCH] Simplify editor cleanup and keep live ASR metadata Keep the daemon path on the full ASR result so word timings and detected language survive into the editor pipeline instead of falling back to a plain transcript string. Add PipelineEngine.run_asr_result(), have aman call it when live ASR data is available, and cover the word-aware alignment behavior in the daemon tests. Collapse the llama cleanup flow to a single JSON-shaped completion while leaving the legacy pass1/pass2 parameters in place as compatibility no-ops. Validated with PYTHONPATH=src python3 -m unittest tests.test_aiprocess tests.test_aman. --- src/aiprocess.py | 147 ++++++++++++++++++---------------------- src/aman.py | 7 +- src/engine/pipeline.py | 10 ++- tests/test_aiprocess.py | 23 +++++++ tests/test_aman.py | 63 +++++++++++++++++ 5 files changed, 166 insertions(+), 84 deletions(-) diff --git a/src/aiprocess.py b/src/aiprocess.py index 093f5f2..40207d9 100644 --- a/src/aiprocess.py +++ b/src/aiprocess.py @@ -207,7 +207,29 @@ PASS2_SYSTEM_PROMPT = ( # Keep a stable symbol for documentation and tooling. -SYSTEM_PROMPT = PASS2_SYSTEM_PROMPT +SYSTEM_PROMPT = ( + "You are an amanuensis working for an user.\n" + "You'll receive a JSON object with the transcript and optional context.\n" + "Your job is to rewrite the user's transcript into clean prose.\n" + "Your output will be directly pasted in the currently focused application on the user computer.\n\n" + "Rules:\n" + "- Preserve meaning, facts, and intent.\n" + "- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n" + "- Preserve wording. Do not replace words for synonyms\n" + "- Do not add new info.\n" + "- Remove filler words (um/uh/like)\n" + "- Remove false starts\n" + "- Remove self-corrections.\n" + "- If a dictionary section exists, apply only the listed corrections.\n" + "- Keep dictionary spellings exactly as provided.\n" + "- Treat domain hints as advisory only; never invent context-specific jargon.\n" + "- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n" + "- Do not wrap with markdown, tags, or extra keys.\n\n" + "Examples:\n" + " - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n" + " - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n" + " - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n" +) class LlamaProcessor: @@ -275,15 +297,8 @@ class LlamaProcessor: min(max_tokens, WARMUP_MAX_TOKENS) if isinstance(max_tokens, int) else WARMUP_MAX_TOKENS ) response = self._invoke_completion( - system_prompt=PASS2_SYSTEM_PROMPT, - user_prompt=_build_pass2_user_prompt_xml( - request_payload, - pass1_payload={ - "candidate_text": request_payload["transcript"], - "decision_spans": [], - }, - pass1_error="", - ), + system_prompt=SYSTEM_PROMPT, + user_prompt=_build_user_prompt_xml(request_payload), profile=profile, temperature=temperature, top_p=top_p, @@ -373,77 +388,43 @@ class LlamaProcessor: pass2_repeat_penalty: float | None = None, pass2_min_p: float | None = None, ) -> tuple[str, ProcessTimings]: + _ = ( + pass1_temperature, + pass1_top_p, + pass1_top_k, + pass1_max_tokens, + pass1_repeat_penalty, + pass1_min_p, + pass2_temperature, + pass2_top_p, + pass2_top_k, + pass2_max_tokens, + pass2_repeat_penalty, + pass2_min_p, + ) request_payload = _build_request_payload( text, lang=lang, dictionary_context=dictionary_context, ) - - p1_temperature = pass1_temperature if pass1_temperature is not None else temperature - p1_top_p = pass1_top_p if pass1_top_p is not None else top_p - p1_top_k = pass1_top_k if pass1_top_k is not None else top_k - p1_max_tokens = pass1_max_tokens if pass1_max_tokens is not None else max_tokens - p1_repeat_penalty = pass1_repeat_penalty if pass1_repeat_penalty is not None else repeat_penalty - p1_min_p = pass1_min_p if pass1_min_p is not None else min_p - - p2_temperature = pass2_temperature if pass2_temperature is not None else temperature - p2_top_p = pass2_top_p if pass2_top_p is not None else top_p - p2_top_k = pass2_top_k if pass2_top_k is not None else top_k - p2_max_tokens = pass2_max_tokens if pass2_max_tokens is not None else max_tokens - p2_repeat_penalty = pass2_repeat_penalty if pass2_repeat_penalty is not None else repeat_penalty - p2_min_p = pass2_min_p if pass2_min_p is not None else min_p - started_total = time.perf_counter() - - started_pass1 = time.perf_counter() - pass1_response = self._invoke_completion( - system_prompt=PASS1_SYSTEM_PROMPT, - user_prompt=_build_pass1_user_prompt_xml(request_payload), + response = self._invoke_completion( + system_prompt=SYSTEM_PROMPT, + user_prompt=_build_user_prompt_xml(request_payload), profile=profile, - temperature=p1_temperature, - top_p=p1_top_p, - top_k=p1_top_k, - max_tokens=p1_max_tokens, - repeat_penalty=p1_repeat_penalty, - min_p=p1_min_p, - adaptive_max_tokens=_recommended_analysis_max_tokens(request_payload["transcript"]), - ) - pass1_ms = (time.perf_counter() - started_pass1) * 1000.0 - - pass1_error = "" - try: - pass1_payload = _extract_pass1_analysis(pass1_response) - except Exception as exc: - pass1_payload = { - "candidate_text": request_payload["transcript"], - "decision_spans": [], - } - pass1_error = str(exc) - - started_pass2 = time.perf_counter() - pass2_response = self._invoke_completion( - system_prompt=PASS2_SYSTEM_PROMPT, - user_prompt=_build_pass2_user_prompt_xml( - request_payload, - pass1_payload=pass1_payload, - pass1_error=pass1_error, - ), - profile=profile, - temperature=p2_temperature, - top_p=p2_top_p, - top_k=p2_top_k, - max_tokens=p2_max_tokens, - repeat_penalty=p2_repeat_penalty, - min_p=p2_min_p, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_tokens=max_tokens, + repeat_penalty=repeat_penalty, + min_p=min_p, adaptive_max_tokens=_recommended_final_max_tokens(request_payload["transcript"], profile), ) - pass2_ms = (time.perf_counter() - started_pass2) * 1000.0 - - cleaned_text = _extract_cleaned_text(pass2_response) + cleaned_text = _extract_cleaned_text(response) total_ms = (time.perf_counter() - started_total) * 1000.0 return cleaned_text, ProcessTimings( - pass1_ms=pass1_ms, - pass2_ms=pass2_ms, + pass1_ms=0.0, + pass2_ms=total_ms, total_ms=total_ms, ) @@ -568,17 +549,7 @@ class ExternalApiProcessor: "model": self.model, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, - { - "role": "user", - "content": _build_pass2_user_prompt_xml( - request_payload, - pass1_payload={ - "candidate_text": request_payload["transcript"], - "decision_spans": [], - }, - pass1_error="", - ), - }, + {"role": "user", "content": _build_user_prompt_xml(request_payload)}, ], "temperature": temperature if temperature is not None else 0.0, "response_format": {"type": "json_object"}, @@ -879,7 +850,19 @@ def _build_pass2_user_prompt_xml( # Backward-compatible helper name. def _build_user_prompt_xml(payload: dict[str, Any]) -> str: - return _build_pass1_user_prompt_xml(payload) + language = escape(str(payload.get("language", "auto"))) + transcript = escape(str(payload.get("transcript", ""))) + dictionary = escape(str(payload.get("dictionary", ""))).strip() + lines = [ + "", + f" {language}", + f" {transcript}", + ] + if dictionary: + lines.append(f" {dictionary}") + lines.append(' {"cleaned_text":"..."}') + lines.append("") + return "\n".join(lines) def _extract_pass1_analysis(payload: Any) -> dict[str, Any]: diff --git a/src/aman.py b/src/aman.py index 384f7dd..7f9d22a 100755 --- a/src/aman.py +++ b/src/aman.py @@ -142,6 +142,7 @@ def _process_transcript_pipeline( stt_lang: str, pipeline: PipelineEngine, suppress_ai_errors: bool, + asr_result: AsrResult | None = None, asr_ms: float = 0.0, verbose: bool = False, ) -> tuple[str, TranscriptProcessTimings]: @@ -161,7 +162,10 @@ def _process_transcript_pipeline( total_ms=asr_ms, ) try: - result = pipeline.run_transcript(processed, language=stt_lang) + if asr_result is not None: + result = pipeline.run_asr_result(asr_result) + else: + result = pipeline.run_transcript(processed, language=stt_lang) except Exception as exc: if suppress_ai_errors: logging.error("editor stage failed: %s", exc) @@ -546,6 +550,7 @@ class Daemon: stt_lang=stt_lang, pipeline=self.pipeline, suppress_ai_errors=False, + asr_result=asr_result, asr_ms=asr_result.latency_ms, verbose=self.log_transcript, ) diff --git a/src/engine/pipeline.py b/src/engine/pipeline.py index b138c75..306100a 100644 --- a/src/engine/pipeline.py +++ b/src/engine/pipeline.py @@ -53,12 +53,20 @@ class PipelineEngine: raise RuntimeError("asr stage is not configured") started = time.perf_counter() asr_result = self._asr_stage.transcribe(audio) + return self.run_asr_result(asr_result, started_at=started) + + def run_asr_result( + self, + asr_result: AsrResult, + *, + started_at: float | None = None, + ) -> PipelineResult: return self._run_transcript_core( asr_result.raw_text, language=asr_result.language, asr_result=asr_result, words=asr_result.words, - started_at=started, + started_at=time.perf_counter() if started_at is None else started_at, ) def run_transcript(self, transcript: str, *, language: str = "auto") -> PipelineResult: diff --git a/tests/test_aiprocess.py b/tests/test_aiprocess.py index 8872903..5e6cd18 100644 --- a/tests/test_aiprocess.py +++ b/tests/test_aiprocess.py @@ -186,6 +186,29 @@ class LlamaWarmupTests(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "expected JSON"): processor.warmup(profile="default") + def test_process_with_metrics_uses_single_completion_timing_shape(self): + processor = object.__new__(LlamaProcessor) + client = _WarmupClient( + {"choices": [{"message": {"content": '{"cleaned_text":"friday"}'}}]} + ) + processor.client = client + + cleaned_text, timings = processor.process_with_metrics( + "thursday, I mean friday", + lang="en", + dictionary_context="", + profile="default", + ) + + self.assertEqual(cleaned_text, "friday") + self.assertEqual(len(client.calls), 1) + call = client.calls[0] + self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT) + self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"]) + self.assertEqual(timings.pass1_ms, 0.0) + self.assertGreater(timings.pass2_ms, 0.0) + self.assertEqual(timings.pass2_ms, timings.total_ms) + class ModelChecksumTests(unittest.TestCase): def test_accepts_expected_checksum_case_insensitive(self): diff --git a/tests/test_aman.py b/tests/test_aman.py index e2923fe..cbf91bf 100644 --- a/tests/test_aman.py +++ b/tests/test_aman.py @@ -12,6 +12,7 @@ if str(SRC) not in sys.path: import aman from config import Config, VocabularyReplacement +from stages.asr_whisper import AsrResult, AsrSegment, AsrWord class FakeDesktop: @@ -144,6 +145,21 @@ class FakeStream: self.close_calls += 1 +def _asr_result(text: str, words: list[str], *, language: str = "auto") -> AsrResult: + asr_words: list[AsrWord] = [] + start = 0.0 + for token in words: + asr_words.append(AsrWord(text=token, start_s=start, end_s=start + 0.1, prob=0.9)) + start += 0.2 + return AsrResult( + raw_text=text, + language=language, + latency_ms=5.0, + words=asr_words, + segments=[AsrSegment(text=text, start_s=0.0, end_s=max(start, 0.1))], + ) + + class DaemonTests(unittest.TestCase): def _config(self) -> Config: cfg = Config() @@ -248,6 +264,53 @@ class DaemonTests(unittest.TestCase): self.assertEqual(desktop.inject_calls, []) self.assertEqual(daemon.get_state(), aman.State.IDLE) + @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) + @patch("aman.start_audio_recording", return_value=(object(), object())) + def test_live_path_uses_asr_words_for_alignment_correction(self, _start_mock, _stop_mock): + desktop = FakeDesktop() + ai_processor = FakeAIProcessor() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False, ai_processor=ai_processor) + daemon.asr_stage.transcribe = lambda _audio: _asr_result( + "set alarm for 6 i mean 7", + ["set", "alarm", "for", "6", "i", "mean", "7"], + language="en", + ) + daemon._start_stop_worker = ( + lambda stream, record, trigger, process_audio: daemon._stop_and_process( + stream, record, trigger, process_audio + ) + ) + + daemon.toggle() + daemon.toggle() + + self.assertEqual(desktop.inject_calls, [("set alarm for 7", "clipboard", False)]) + self.assertEqual(ai_processor.last_kwargs.get("lang"), "en") + + @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) + @patch("aman.start_audio_recording", return_value=(object(), object())) + def test_live_path_calls_word_aware_pipeline_entrypoint(self, _start_mock, _stop_mock): + desktop = FakeDesktop() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False) + asr_result = _asr_result( + "set alarm for 6 i mean 7", + ["set", "alarm", "for", "6", "i", "mean", "7"], + language="en", + ) + daemon.asr_stage.transcribe = lambda _audio: asr_result + daemon._start_stop_worker = ( + lambda stream, record, trigger, process_audio: daemon._stop_and_process( + stream, record, trigger, process_audio + ) + ) + + with patch.object(daemon.pipeline, "run_asr_result", wraps=daemon.pipeline.run_asr_result) as run_asr: + daemon.toggle() + daemon.toggle() + + run_asr.assert_called_once() + self.assertIs(run_asr.call_args.args[0], asr_result) + def test_transcribe_skips_hints_when_model_does_not_support_them(self): desktop = FakeDesktop() model = FakeModel(text="hello")