Simplify editor cleanup and keep live ASR metadata
Keep the daemon path on the full ASR result so word timings and detected language survive into the editor pipeline instead of falling back to a plain transcript string. Add PipelineEngine.run_asr_result(), have aman call it when live ASR data is available, and cover the word-aware alignment behavior in the daemon tests. Collapse the llama cleanup flow to a single JSON-shaped completion while leaving the legacy pass1/pass2 parameters in place as compatibility no-ops. Validated with PYTHONPATH=src python3 -m unittest tests.test_aiprocess tests.test_aman.
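As a sketch of the two entry points after this change (assuming an already-configured PipelineEngine instance named `pipeline`; the AsrResult values mirror the `_asr_result()` helper added to tests/test_aman.py below):

    from stages.asr_whisper import AsrResult, AsrSegment, AsrWord

    asr_result = AsrResult(
        raw_text="set alarm for 6 i mean 7",
        language="en",
        latency_ms=5.0,
        words=[AsrWord(text="set", start_s=0.0, end_s=0.1, prob=0.9)],
        segments=[AsrSegment(text="set alarm for 6 i mean 7", start_s=0.0, end_s=1.4)],
    )

    # Live path: word timings and the detected language ride along into the editor.
    result = pipeline.run_asr_result(asr_result)

    # Fallback path when no live ASR data is available: plain string, as before.
    result = pipeline.run_transcript(asr_result.raw_text, language="en")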
parent 8c1f7c1e13
commit fa91f313c4
5 changed files with 166 additions and 84 deletions
src/aiprocess.py (147 changed lines)
@@ -207,7 +207,29 @@ PASS2_SYSTEM_PROMPT = (
 
 
 # Keep a stable symbol for documentation and tooling.
-SYSTEM_PROMPT = PASS2_SYSTEM_PROMPT
+SYSTEM_PROMPT = (
+    "You are an amanuensis working for an user.\n"
+    "You'll receive a JSON object with the transcript and optional context.\n"
+    "Your job is to rewrite the user's transcript into clean prose.\n"
+    "Your output will be directly pasted in the currently focused application on the user computer.\n\n"
+    "Rules:\n"
+    "- Preserve meaning, facts, and intent.\n"
+    "- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n"
+    "- Preserve wording. Do not replace words for synonyms\n"
+    "- Do not add new info.\n"
+    "- Remove filler words (um/uh/like)\n"
+    "- Remove false starts\n"
+    "- Remove self-corrections.\n"
+    "- If a dictionary section exists, apply only the listed corrections.\n"
+    "- Keep dictionary spellings exactly as provided.\n"
+    "- Treat domain hints as advisory only; never invent context-specific jargon.\n"
+    "- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
+    "- Do not wrap with markdown, tags, or extra keys.\n\n"
+    "Examples:\n"
+    " - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
+    " - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
+    " - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
+)
 
 
 class LlamaProcessor:
@@ -275,15 +297,8 @@ class LlamaProcessor:
             min(max_tokens, WARMUP_MAX_TOKENS) if isinstance(max_tokens, int) else WARMUP_MAX_TOKENS
         )
         response = self._invoke_completion(
-            system_prompt=PASS2_SYSTEM_PROMPT,
-            user_prompt=_build_pass2_user_prompt_xml(
-                request_payload,
-                pass1_payload={
-                    "candidate_text": request_payload["transcript"],
-                    "decision_spans": [],
-                },
-                pass1_error="",
-            ),
+            system_prompt=SYSTEM_PROMPT,
+            user_prompt=_build_user_prompt_xml(request_payload),
             profile=profile,
             temperature=temperature,
             top_p=top_p,
@@ -373,77 +388,43 @@ class LlamaProcessor:
         pass2_repeat_penalty: float | None = None,
         pass2_min_p: float | None = None,
     ) -> tuple[str, ProcessTimings]:
+        _ = (
+            pass1_temperature,
+            pass1_top_p,
+            pass1_top_k,
+            pass1_max_tokens,
+            pass1_repeat_penalty,
+            pass1_min_p,
+            pass2_temperature,
+            pass2_top_p,
+            pass2_top_k,
+            pass2_max_tokens,
+            pass2_repeat_penalty,
+            pass2_min_p,
+        )
         request_payload = _build_request_payload(
             text,
             lang=lang,
             dictionary_context=dictionary_context,
         )
-
-        p1_temperature = pass1_temperature if pass1_temperature is not None else temperature
-        p1_top_p = pass1_top_p if pass1_top_p is not None else top_p
-        p1_top_k = pass1_top_k if pass1_top_k is not None else top_k
-        p1_max_tokens = pass1_max_tokens if pass1_max_tokens is not None else max_tokens
-        p1_repeat_penalty = pass1_repeat_penalty if pass1_repeat_penalty is not None else repeat_penalty
-        p1_min_p = pass1_min_p if pass1_min_p is not None else min_p
-
-        p2_temperature = pass2_temperature if pass2_temperature is not None else temperature
-        p2_top_p = pass2_top_p if pass2_top_p is not None else top_p
-        p2_top_k = pass2_top_k if pass2_top_k is not None else top_k
-        p2_max_tokens = pass2_max_tokens if pass2_max_tokens is not None else max_tokens
-        p2_repeat_penalty = pass2_repeat_penalty if pass2_repeat_penalty is not None else repeat_penalty
-        p2_min_p = pass2_min_p if pass2_min_p is not None else min_p
-
         started_total = time.perf_counter()
-        started_pass1 = time.perf_counter()
-        pass1_response = self._invoke_completion(
-            system_prompt=PASS1_SYSTEM_PROMPT,
-            user_prompt=_build_pass1_user_prompt_xml(request_payload),
+        response = self._invoke_completion(
+            system_prompt=SYSTEM_PROMPT,
+            user_prompt=_build_user_prompt_xml(request_payload),
             profile=profile,
-            temperature=p1_temperature,
-            top_p=p1_top_p,
-            top_k=p1_top_k,
-            max_tokens=p1_max_tokens,
-            repeat_penalty=p1_repeat_penalty,
-            min_p=p1_min_p,
-            adaptive_max_tokens=_recommended_analysis_max_tokens(request_payload["transcript"]),
-        )
-        pass1_ms = (time.perf_counter() - started_pass1) * 1000.0
-
-        pass1_error = ""
-        try:
-            pass1_payload = _extract_pass1_analysis(pass1_response)
-        except Exception as exc:
-            pass1_payload = {
-                "candidate_text": request_payload["transcript"],
-                "decision_spans": [],
-            }
-            pass1_error = str(exc)
-
-        started_pass2 = time.perf_counter()
-        pass2_response = self._invoke_completion(
-            system_prompt=PASS2_SYSTEM_PROMPT,
-            user_prompt=_build_pass2_user_prompt_xml(
-                request_payload,
-                pass1_payload=pass1_payload,
-                pass1_error=pass1_error,
-            ),
-            profile=profile,
-            temperature=p2_temperature,
-            top_p=p2_top_p,
-            top_k=p2_top_k,
-            max_tokens=p2_max_tokens,
-            repeat_penalty=p2_repeat_penalty,
-            min_p=p2_min_p,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_tokens=max_tokens,
+            repeat_penalty=repeat_penalty,
+            min_p=min_p,
             adaptive_max_tokens=_recommended_final_max_tokens(request_payload["transcript"], profile),
         )
-        pass2_ms = (time.perf_counter() - started_pass2) * 1000.0
-
-        cleaned_text = _extract_cleaned_text(pass2_response)
+        cleaned_text = _extract_cleaned_text(response)
         total_ms = (time.perf_counter() - started_total) * 1000.0
         return cleaned_text, ProcessTimings(
-            pass1_ms=pass1_ms,
-            pass2_ms=pass2_ms,
+            pass1_ms=0.0,
+            pass2_ms=total_ms,
             total_ms=total_ms,
         )
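With the flow collapsed to one completion, the timing fields keep their old shape for existing readers: pass1_ms is reported as 0.0 and pass2_ms carries the whole run. A minimal illustration (mirrors the new unit test further down; `processor` is an assumed, already-configured LlamaProcessor):

    # Single completion, legacy-shaped timings: total work lands in pass2_ms.
    cleaned_text, timings = processor.process_with_metrics(
        "thursday, I mean friday",
        lang="en",
        dictionary_context="",
        profile="default",
    )
    assert timings.pass1_ms == 0.0
    assert timings.pass2_ms == timings.total_ms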
@@ -568,17 +549,7 @@ class ExternalApiProcessor:
             "model": self.model,
             "messages": [
                 {"role": "system", "content": SYSTEM_PROMPT},
-                {
-                    "role": "user",
-                    "content": _build_pass2_user_prompt_xml(
-                        request_payload,
-                        pass1_payload={
-                            "candidate_text": request_payload["transcript"],
-                            "decision_spans": [],
-                        },
-                        pass1_error="",
-                    ),
-                },
+                {"role": "user", "content": _build_user_prompt_xml(request_payload)},
             ],
             "temperature": temperature if temperature is not None else 0.0,
             "response_format": {"type": "json_object"},
@@ -879,7 +850,19 @@ def _build_pass2_user_prompt_xml(
 # Backward-compatible helper name.
 def _build_user_prompt_xml(payload: dict[str, Any]) -> str:
-    return _build_pass1_user_prompt_xml(payload)
+    language = escape(str(payload.get("language", "auto")))
+    transcript = escape(str(payload.get("transcript", "")))
+    dictionary = escape(str(payload.get("dictionary", ""))).strip()
+    lines = [
+        "<request>",
+        f" <language>{language}</language>",
+        f" <transcript>{transcript}</transcript>",
+    ]
+    if dictionary:
+        lines.append(f" <dictionary>{dictionary}</dictionary>")
+    lines.append(' <output_contract>{"cleaned_text":"..."}</output_contract>')
+    lines.append("</request>")
+    return "\n".join(lines)
 
 
 def _extract_pass1_analysis(payload: Any) -> dict[str, Any]:
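For reference, the rebuilt helper emits a flat request document. A worked example (payload values are illustrative; escaping follows the escape() call in the helper):

    payload = {"language": "en", "transcript": "ship it friday", "dictionary": "K8s -> Kubernetes"}
    print(_build_user_prompt_xml(payload))
    # <request>
    #  <language>en</language>
    #  <transcript>ship it friday</transcript>
    #  <dictionary>K8s -&gt; Kubernetes</dictionary>
    #  <output_contract>{"cleaned_text":"..."}</output_contract>
    # </request>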
src/aman.py

@@ -142,6 +142,7 @@ def _process_transcript_pipeline(
     stt_lang: str,
     pipeline: PipelineEngine,
     suppress_ai_errors: bool,
+    asr_result: AsrResult | None = None,
     asr_ms: float = 0.0,
     verbose: bool = False,
 ) -> tuple[str, TranscriptProcessTimings]:
@@ -161,7 +162,10 @@ def _process_transcript_pipeline(
         total_ms=asr_ms,
     )
     try:
-        result = pipeline.run_transcript(processed, language=stt_lang)
+        if asr_result is not None:
+            result = pipeline.run_asr_result(asr_result)
+        else:
+            result = pipeline.run_transcript(processed, language=stt_lang)
     except Exception as exc:
         if suppress_ai_errors:
             logging.error("editor stage failed: %s", exc)
@@ -546,6 +550,7 @@ class Daemon:
             stt_lang=stt_lang,
             pipeline=self.pipeline,
             suppress_ai_errors=False,
+            asr_result=asr_result,
             asr_ms=asr_result.latency_ms,
             verbose=self.log_transcript,
         )
src/pipeline.py

@@ -53,12 +53,20 @@ class PipelineEngine:
             raise RuntimeError("asr stage is not configured")
         started = time.perf_counter()
         asr_result = self._asr_stage.transcribe(audio)
+        return self.run_asr_result(asr_result, started_at=started)
+
+    def run_asr_result(
+        self,
+        asr_result: AsrResult,
+        *,
+        started_at: float | None = None,
+    ) -> PipelineResult:
         return self._run_transcript_core(
             asr_result.raw_text,
             language=asr_result.language,
             asr_result=asr_result,
             words=asr_result.words,
-            started_at=started,
+            started_at=time.perf_counter() if started_at is None else started_at,
         )
 
     def run_transcript(self, transcript: str, *, language: str = "auto") -> PipelineResult:
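The new entry point also accepts a caller-supplied start time so latency accounting can cover work done before the pipeline call. A short sketch (`pipeline` and `asr_result` are assumed to exist as in the example under the commit message):

    import time

    started = time.perf_counter()
    # ... transcription or other pre-processing happens here ...
    result = pipeline.run_asr_result(asr_result, started_at=started)

    # Omitting started_at makes the pipeline start the clock itself:
    result = pipeline.run_asr_result(asr_result)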
tests/test_aiprocess.py

@@ -186,6 +186,29 @@ class LlamaWarmupTests(unittest.TestCase):
         with self.assertRaisesRegex(RuntimeError, "expected JSON"):
             processor.warmup(profile="default")
 
+    def test_process_with_metrics_uses_single_completion_timing_shape(self):
+        processor = object.__new__(LlamaProcessor)
+        client = _WarmupClient(
+            {"choices": [{"message": {"content": '{"cleaned_text":"friday"}'}}]}
+        )
+        processor.client = client
+
+        cleaned_text, timings = processor.process_with_metrics(
+            "thursday, I mean friday",
+            lang="en",
+            dictionary_context="",
+            profile="default",
+        )
+
+        self.assertEqual(cleaned_text, "friday")
+        self.assertEqual(len(client.calls), 1)
+        call = client.calls[0]
+        self.assertEqual(call["messages"][0]["content"], aiprocess.SYSTEM_PROMPT)
+        self.assertIn('{"cleaned_text":"..."}', call["messages"][1]["content"])
+        self.assertEqual(timings.pass1_ms, 0.0)
+        self.assertGreater(timings.pass2_ms, 0.0)
+        self.assertEqual(timings.pass2_ms, timings.total_ms)
+
 
 class ModelChecksumTests(unittest.TestCase):
     def test_accepts_expected_checksum_case_insensitive(self):
tests/test_aman.py

@@ -12,6 +12,7 @@ if str(SRC) not in sys.path:
 
 import aman
 from config import Config, VocabularyReplacement
+from stages.asr_whisper import AsrResult, AsrSegment, AsrWord
 
 
 class FakeDesktop:
@@ -144,6 +145,21 @@ class FakeStream:
         self.close_calls += 1
 
 
+def _asr_result(text: str, words: list[str], *, language: str = "auto") -> AsrResult:
+    asr_words: list[AsrWord] = []
+    start = 0.0
+    for token in words:
+        asr_words.append(AsrWord(text=token, start_s=start, end_s=start + 0.1, prob=0.9))
+        start += 0.2
+    return AsrResult(
+        raw_text=text,
+        language=language,
+        latency_ms=5.0,
+        words=asr_words,
+        segments=[AsrSegment(text=text, start_s=0.0, end_s=max(start, 0.1))],
+    )
+
+
 class DaemonTests(unittest.TestCase):
     def _config(self) -> Config:
         cfg = Config()
@@ -248,6 +264,53 @@ class DaemonTests(unittest.TestCase):
         self.assertEqual(desktop.inject_calls, [])
         self.assertEqual(daemon.get_state(), aman.State.IDLE)
 
+    @patch("aman.stop_audio_recording", return_value=FakeAudio(8))
+    @patch("aman.start_audio_recording", return_value=(object(), object()))
+    def test_live_path_uses_asr_words_for_alignment_correction(self, _start_mock, _stop_mock):
+        desktop = FakeDesktop()
+        ai_processor = FakeAIProcessor()
+        daemon = self._build_daemon(desktop, FakeModel(), verbose=False, ai_processor=ai_processor)
+        daemon.asr_stage.transcribe = lambda _audio: _asr_result(
+            "set alarm for 6 i mean 7",
+            ["set", "alarm", "for", "6", "i", "mean", "7"],
+            language="en",
+        )
+        daemon._start_stop_worker = (
+            lambda stream, record, trigger, process_audio: daemon._stop_and_process(
+                stream, record, trigger, process_audio
+            )
+        )
+
+        daemon.toggle()
+        daemon.toggle()
+
+        self.assertEqual(desktop.inject_calls, [("set alarm for 7", "clipboard", False)])
+        self.assertEqual(ai_processor.last_kwargs.get("lang"), "en")
+
+    @patch("aman.stop_audio_recording", return_value=FakeAudio(8))
+    @patch("aman.start_audio_recording", return_value=(object(), object()))
+    def test_live_path_calls_word_aware_pipeline_entrypoint(self, _start_mock, _stop_mock):
+        desktop = FakeDesktop()
+        daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
+        asr_result = _asr_result(
+            "set alarm for 6 i mean 7",
+            ["set", "alarm", "for", "6", "i", "mean", "7"],
+            language="en",
+        )
+        daemon.asr_stage.transcribe = lambda _audio: asr_result
+        daemon._start_stop_worker = (
+            lambda stream, record, trigger, process_audio: daemon._stop_and_process(
+                stream, record, trigger, process_audio
+            )
+        )
+
+        with patch.object(daemon.pipeline, "run_asr_result", wraps=daemon.pipeline.run_asr_result) as run_asr:
+            daemon.toggle()
+            daemon.toggle()
+
+        run_asr.assert_called_once()
+        self.assertIs(run_asr.call_args.args[0], asr_result)
+
     def test_transcribe_skips_hints_when_model_does_not_support_them(self):
         desktop = FakeDesktop()
         model = FakeModel(text="hello")