Remove log_transcript config and enforce JSON AI output

This commit is contained in:
Thales Maciel 2026-02-25 10:23:56 -03:00
parent c3503fbbde
commit 1423e44008
8 changed files with 198 additions and 62 deletions

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import ctypes
import inspect
import json
import logging
import os
import sys
@ -12,7 +14,7 @@ from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
SYSTEM_PROMPT = (
"You are an amanuensis working for an user.\n"
"You'll receive an accurate transcript of what the user said.\n"
"You'll receive a JSON object with the transcript and optional context.\n"
"Your job is to rewrite the user's transcript into clean prose.\n"
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
@ -24,14 +26,15 @@ SYSTEM_PROMPT = (
"- Remove filler words (um/uh/like)\n"
"- Remove false starts\n"
"- Remove self-corrections.\n"
"- If a <dictionary> section exists, apply only the listed corrections.\n"
"- If a dictionary section exists, apply only the listed corrections.\n"
"- Keep dictionary spellings exactly as provided.\n"
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
"- Output ONLY the cleaned text, no commentary.\n\n"
"- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
"- Do not wrap with markdown, tags, or extra keys.\n\n"
"Examples:\n"
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
" - \"Good morning Martha, nice to meet you!\" -> \"Good morning Martha, nice to meet you!\"\n"
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
" - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
" - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
" - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
)
@ -61,22 +64,30 @@ class LlamaProcessor:
domain_name: str = "general",
domain_confidence: float = 0.0,
) -> str:
blocks = [
f"<language>{lang}</language>",
f'<domain name="{domain_name}" confidence="{domain_confidence:.2f}"/>',
]
if dictionary_context.strip():
blocks.append(f"<dictionary>\n{dictionary_context.strip()}\n</dictionary>")
blocks.append(f"<transcript>{text}</transcript>")
user_content = "\n".join(blocks)
response = self.client.create_chat_completion(
messages=[
request_payload: dict[str, Any] = {
"language": lang,
"domain": {
"name": domain_name,
"confidence": round(float(domain_confidence), 2),
},
"transcript": text,
}
cleaned_dictionary = dictionary_context.strip()
if cleaned_dictionary:
request_payload["dictionary"] = cleaned_dictionary
kwargs: dict[str, Any] = {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
{"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
],
temperature=0.0,
)
return _extract_chat_text(response)
"temperature": 0.0,
}
if _supports_response_format(self.client.create_chat_completion):
kwargs["response_format"] = {"type": "json_object"}
response = self.client.create_chat_completion(**kwargs)
return _extract_cleaned_text(response)
def ensure_model():
@ -135,6 +146,32 @@ def _extract_chat_text(payload: Any) -> str:
raise RuntimeError("unexpected response format")
def _extract_cleaned_text(payload: Any) -> str:
    """Parse the model's chat response and return the cleaned transcript.

    The model is instructed to reply with JSON shaped like
    ``{"cleaned_text": "..."}``; a bare JSON string is tolerated as a
    lenient fallback.

    Raises:
        RuntimeError: if the output is not valid JSON, or if it parses
            to something without a string ``cleaned_text`` field.
    """
    raw = _extract_chat_text(payload)
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise RuntimeError("unexpected ai output format: expected JSON") from exc
    # Lenient path: the model emitted a plain JSON string instead of an object.
    if isinstance(decoded, str):
        return decoded
    candidate = decoded.get("cleaned_text") if isinstance(decoded, dict) else None
    if isinstance(candidate, str):
        return candidate
    raise RuntimeError("unexpected ai output format: missing cleaned_text")
def _supports_response_format(chat_completion: Callable[..., Any]) -> bool:
try:
signature = inspect.signature(chat_completion)
except (TypeError, ValueError):
return False
return "response_format" in signature.parameters
def _llama_log_callback_factory(verbose: bool) -> Callable:
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)