diff --git a/src/aiprocess.py b/src/aiprocess.py index 093f5f2..40207d9 100644 --- a/src/aiprocess.py +++ b/src/aiprocess.py @@ -207,7 +207,29 @@ PASS2_SYSTEM_PROMPT = ( # Keep a stable symbol for documentation and tooling. -SYSTEM_PROMPT = PASS2_SYSTEM_PROMPT +SYSTEM_PROMPT = ( + "You are an amanuensis working for a user.\n" + "You'll receive a JSON object with the transcript and optional context.\n" + "Your job is to rewrite the user's transcript into clean prose.\n" + "Your output will be directly pasted in the currently focused application on the user's computer.\n\n" + "Rules:\n" + "- Preserve meaning, facts, and intent.\n" + "- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n" + "- Preserve wording. Do not replace words with synonyms.\n" + "- Do not add new info.\n" + "- Remove filler words (um/uh/like)\n" + "- Remove false starts\n" + "- Remove self-corrections.\n" + "- If a dictionary section exists, apply only the listed corrections.\n" + "- Keep dictionary spellings exactly as provided.\n" + "- Treat domain hints as advisory only; never invent context-specific jargon.\n" + "- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n" + "- Do not wrap with markdown, tags, or extra keys.\n\n" + "Examples:\n" + " - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n" + " - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n" + " - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n" +) class LlamaProcessor: @@ -275,15 +297,8 @@ class LlamaProcessor: min(max_tokens, WARMUP_MAX_TOKENS) if isinstance(max_tokens, int) else WARMUP_MAX_TOKENS ) response = self._invoke_completion( - system_prompt=PASS2_SYSTEM_PROMPT, - user_prompt=_build_pass2_user_prompt_xml( - request_payload, - pass1_payload={ - "candidate_text": request_payload["transcript"], - 
"decision_spans": [], - }, - pass1_error="", - ), + system_prompt=SYSTEM_PROMPT, + user_prompt=_build_user_prompt_xml(request_payload), profile=profile, temperature=temperature, top_p=top_p, @@ -373,77 +388,43 @@ class LlamaProcessor: pass2_repeat_penalty: float | None = None, pass2_min_p: float | None = None, ) -> tuple[str, ProcessTimings]: + _ = ( + pass1_temperature, + pass1_top_p, + pass1_top_k, + pass1_max_tokens, + pass1_repeat_penalty, + pass1_min_p, + pass2_temperature, + pass2_top_p, + pass2_top_k, + pass2_max_tokens, + pass2_repeat_penalty, + pass2_min_p, + ) request_payload = _build_request_payload( text, lang=lang, dictionary_context=dictionary_context, ) - - p1_temperature = pass1_temperature if pass1_temperature is not None else temperature - p1_top_p = pass1_top_p if pass1_top_p is not None else top_p - p1_top_k = pass1_top_k if pass1_top_k is not None else top_k - p1_max_tokens = pass1_max_tokens if pass1_max_tokens is not None else max_tokens - p1_repeat_penalty = pass1_repeat_penalty if pass1_repeat_penalty is not None else repeat_penalty - p1_min_p = pass1_min_p if pass1_min_p is not None else min_p - - p2_temperature = pass2_temperature if pass2_temperature is not None else temperature - p2_top_p = pass2_top_p if pass2_top_p is not None else top_p - p2_top_k = pass2_top_k if pass2_top_k is not None else top_k - p2_max_tokens = pass2_max_tokens if pass2_max_tokens is not None else max_tokens - p2_repeat_penalty = pass2_repeat_penalty if pass2_repeat_penalty is not None else repeat_penalty - p2_min_p = pass2_min_p if pass2_min_p is not None else min_p - started_total = time.perf_counter() - - started_pass1 = time.perf_counter() - pass1_response = self._invoke_completion( - system_prompt=PASS1_SYSTEM_PROMPT, - user_prompt=_build_pass1_user_prompt_xml(request_payload), + response = self._invoke_completion( + system_prompt=SYSTEM_PROMPT, + user_prompt=_build_user_prompt_xml(request_payload), profile=profile, - temperature=p1_temperature, - 
top_p=p1_top_p, - top_k=p1_top_k, - max_tokens=p1_max_tokens, - repeat_penalty=p1_repeat_penalty, - min_p=p1_min_p, - adaptive_max_tokens=_recommended_analysis_max_tokens(request_payload["transcript"]), - ) - pass1_ms = (time.perf_counter() - started_pass1) * 1000.0 - - pass1_error = "" - try: - pass1_payload = _extract_pass1_analysis(pass1_response) - except Exception as exc: - pass1_payload = { - "candidate_text": request_payload["transcript"], - "decision_spans": [], - } - pass1_error = str(exc) - - started_pass2 = time.perf_counter() - pass2_response = self._invoke_completion( - system_prompt=PASS2_SYSTEM_PROMPT, - user_prompt=_build_pass2_user_prompt_xml( - request_payload, - pass1_payload=pass1_payload, - pass1_error=pass1_error, - ), - profile=profile, - temperature=p2_temperature, - top_p=p2_top_p, - top_k=p2_top_k, - max_tokens=p2_max_tokens, - repeat_penalty=p2_repeat_penalty, - min_p=p2_min_p, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_tokens=max_tokens, + repeat_penalty=repeat_penalty, + min_p=min_p, adaptive_max_tokens=_recommended_final_max_tokens(request_payload["transcript"], profile), ) - pass2_ms = (time.perf_counter() - started_pass2) * 1000.0 - - cleaned_text = _extract_cleaned_text(pass2_response) + cleaned_text = _extract_cleaned_text(response) total_ms = (time.perf_counter() - started_total) * 1000.0 return cleaned_text, ProcessTimings( - pass1_ms=pass1_ms, - pass2_ms=pass2_ms, + pass1_ms=0.0, + pass2_ms=total_ms, total_ms=total_ms, ) @@ -568,17 +549,7 @@ class ExternalApiProcessor: "model": self.model, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, - { - "role": "user", - "content": _build_pass2_user_prompt_xml( - request_payload, - pass1_payload={ - "candidate_text": request_payload["transcript"], - "decision_spans": [], - }, - pass1_error="", - ), - }, + {"role": "user", "content": _build_user_prompt_xml(request_payload)}, ], "temperature": temperature if temperature is not None else 0.0, "response_format": 
{"type": "json_object"}, @@ -879,7 +850,19 @@ def _build_pass2_user_prompt_xml( # Backward-compatible helper name. def _build_user_prompt_xml(payload: dict[str, Any]) -> str: - return _build_pass1_user_prompt_xml(payload) + language = escape(str(payload.get("language", "auto"))) + transcript = escape(str(payload.get("transcript", ""))) + dictionary = escape(str(payload.get("dictionary", ""))).strip() + lines = [ + "", + f" {language}", + f" {transcript}", + ] + if dictionary: + lines.append(f" {dictionary}") + lines.append(' {"cleaned_text":"..."}') + lines.append("") + return "\n".join(lines) def _extract_pass1_analysis(payload: Any) -> dict[str, Any]: diff --git a/src/aman.py b/src/aman.py index 384f7dd..7f9d22a 100755 --- a/src/aman.py +++ b/src/aman.py @@ -142,6 +142,7 @@ def _process_transcript_pipeline( stt_lang: str, pipeline: PipelineEngine, suppress_ai_errors: bool, + asr_result: AsrResult | None = None, asr_ms: float = 0.0, verbose: bool = False, ) -> tuple[str, TranscriptProcessTimings]: @@ -161,7 +162,10 @@ def _process_transcript_pipeline( total_ms=asr_ms, ) try: - result = pipeline.run_transcript(processed, language=stt_lang) + if asr_result is not None: + result = pipeline.run_asr_result(asr_result) + else: + result = pipeline.run_transcript(processed, language=stt_lang) except Exception as exc: if suppress_ai_errors: logging.error("editor stage failed: %s", exc) @@ -546,6 +550,7 @@ class Daemon: stt_lang=stt_lang, pipeline=self.pipeline, suppress_ai_errors=False, + asr_result=asr_result, asr_ms=asr_result.latency_ms, verbose=self.log_transcript, ) diff --git a/src/engine/pipeline.py b/src/engine/pipeline.py index b138c75..306100a 100644 --- a/src/engine/pipeline.py +++ b/src/engine/pipeline.py @@ -53,12 +53,20 @@ class PipelineEngine: raise RuntimeError("asr stage is not configured") started = time.perf_counter() asr_result = self._asr_stage.transcribe(audio) + return self.run_asr_result(asr_result, started_at=started) + + def run_asr_result( + 
self, + asr_result: AsrResult, + *, + started_at: float | None = None, + ) -> PipelineResult: return self._run_transcript_core( asr_result.raw_text, language=asr_result.language, asr_result=asr_result, words=asr_result.words, - started_at=started, + started_at=time.perf_counter() if started_at is None else started_at, ) def run_transcript(self, transcript: str, *, language: str = "auto") -> PipelineResult: diff --git a/tests/test_aiprocess.py b/tests/test_aiprocess.py index 8872903..5e6cd18 100644 --- a/tests/test_aiprocess.py +++ b/tests/test_aiprocess.py @@ -186,6 +186,29 @@ class LlamaWarmupTests(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "expected JSON"): processor.warmup(profile="default") + def test_process_with_metrics_uses_single_completion_timing_shape(self): + processor = object.__new__(LlamaProcessor) + client = _WarmupClient( + {"choices": [{"message": {"content": '{"cleaned_text":"friday"}'}}]} + ) + processor.client = client + + cleaned_text, timings = processor.process_with_metrics( + "thursday, I mean friday", + lang="en", + dictionary_context="", + profile="default", + ) + + self.assertEqual(cleaned_text, "friday") + self.assertEqual(len(client.calls), 1) + call = client.calls[0] + self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT) + self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"]) + self.assertEqual(timings.pass1_ms, 0.0) + self.assertGreater(timings.pass2_ms, 0.0) + self.assertEqual(timings.pass2_ms, timings.total_ms) + class ModelChecksumTests(unittest.TestCase): def test_accepts_expected_checksum_case_insensitive(self): diff --git a/tests/test_aman.py b/tests/test_aman.py index e2923fe..cbf91bf 100644 --- a/tests/test_aman.py +++ b/tests/test_aman.py @@ -12,6 +12,7 @@ if str(SRC) not in sys.path: import aman from config import Config, VocabularyReplacement +from stages.asr_whisper import AsrResult, AsrSegment, AsrWord class FakeDesktop: @@ -144,6 +145,21 @@ class FakeStream: 
self.close_calls += 1 +def _asr_result(text: str, words: list[str], *, language: str = "auto") -> AsrResult: + asr_words: list[AsrWord] = [] + start = 0.0 + for token in words: + asr_words.append(AsrWord(text=token, start_s=start, end_s=start + 0.1, prob=0.9)) + start += 0.2 + return AsrResult( + raw_text=text, + language=language, + latency_ms=5.0, + words=asr_words, + segments=[AsrSegment(text=text, start_s=0.0, end_s=max(start, 0.1))], + ) + + class DaemonTests(unittest.TestCase): def _config(self) -> Config: cfg = Config() @@ -248,6 +264,53 @@ class DaemonTests(unittest.TestCase): self.assertEqual(desktop.inject_calls, []) self.assertEqual(daemon.get_state(), aman.State.IDLE) + @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) + @patch("aman.start_audio_recording", return_value=(object(), object())) + def test_live_path_uses_asr_words_for_alignment_correction(self, _start_mock, _stop_mock): + desktop = FakeDesktop() + ai_processor = FakeAIProcessor() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False, ai_processor=ai_processor) + daemon.asr_stage.transcribe = lambda _audio: _asr_result( + "set alarm for 6 i mean 7", + ["set", "alarm", "for", "6", "i", "mean", "7"], + language="en", + ) + daemon._start_stop_worker = ( + lambda stream, record, trigger, process_audio: daemon._stop_and_process( + stream, record, trigger, process_audio + ) + ) + + daemon.toggle() + daemon.toggle() + + self.assertEqual(desktop.inject_calls, [("set alarm for 7", "clipboard", False)]) + self.assertEqual(ai_processor.last_kwargs.get("lang"), "en") + + @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) + @patch("aman.start_audio_recording", return_value=(object(), object())) + def test_live_path_calls_word_aware_pipeline_entrypoint(self, _start_mock, _stop_mock): + desktop = FakeDesktop() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False) + asr_result = _asr_result( + "set alarm for 6 i mean 7", + ["set", "alarm", "for", "6", "i", 
"mean", "7"], + language="en", + ) + daemon.asr_stage.transcribe = lambda _audio: asr_result + daemon._start_stop_worker = ( + lambda stream, record, trigger, process_audio: daemon._stop_and_process( + stream, record, trigger, process_audio + ) + ) + + with patch.object(daemon.pipeline, "run_asr_result", wraps=daemon.pipeline.run_asr_result) as run_asr: + daemon.toggle() + daemon.toggle() + + run_asr.assert_called_once() + self.assertIs(run_asr.call_args.args[0], asr_result) + def test_transcribe_skips_hints_when_model_does_not_support_them(self): desktop = FakeDesktop() model = FakeModel(text="hello")