Remove log_transcript config and enforce JSON AI output
This commit is contained in:
parent
c3503fbbde
commit
1423e44008
8 changed files with 198 additions and 62 deletions
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -12,7 +14,7 @@ from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
|
|||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are an amanuensis working for an user.\n"
|
||||
"You'll receive an accurate transcript of what the user said.\n"
|
||||
"You'll receive a JSON object with the transcript and optional context.\n"
|
||||
"Your job is to rewrite the user's transcript into clean prose.\n"
|
||||
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
|
||||
|
||||
|
|
@ -24,14 +26,15 @@ SYSTEM_PROMPT = (
|
|||
"- Remove filler words (um/uh/like)\n"
|
||||
"- Remove false starts\n"
|
||||
"- Remove self-corrections.\n"
|
||||
"- If a <dictionary> section exists, apply only the listed corrections.\n"
|
||||
"- If a dictionary section exists, apply only the listed corrections.\n"
|
||||
"- Keep dictionary spellings exactly as provided.\n"
|
||||
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
|
||||
"- Output ONLY the cleaned text, no commentary.\n\n"
|
||||
"- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
|
||||
"- Do not wrap with markdown, tags, or extra keys.\n\n"
|
||||
"Examples:\n"
|
||||
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
|
||||
" - \"Good morning Martha, nice to meet you!\" -> \"Good morning Martha, nice to meet you!\"\n"
|
||||
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
|
||||
" - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
|
||||
" - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
|
||||
" - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -61,22 +64,30 @@ class LlamaProcessor:
|
|||
domain_name: str = "general",
|
||||
domain_confidence: float = 0.0,
|
||||
) -> str:
|
||||
blocks = [
|
||||
f"<language>{lang}</language>",
|
||||
f'<domain name="{domain_name}" confidence="{domain_confidence:.2f}"/>',
|
||||
]
|
||||
if dictionary_context.strip():
|
||||
blocks.append(f"<dictionary>\n{dictionary_context.strip()}\n</dictionary>")
|
||||
blocks.append(f"<transcript>{text}</transcript>")
|
||||
user_content = "\n".join(blocks)
|
||||
response = self.client.create_chat_completion(
|
||||
messages=[
|
||||
request_payload: dict[str, Any] = {
|
||||
"language": lang,
|
||||
"domain": {
|
||||
"name": domain_name,
|
||||
"confidence": round(float(domain_confidence), 2),
|
||||
},
|
||||
"transcript": text,
|
||||
}
|
||||
cleaned_dictionary = dictionary_context.strip()
|
||||
if cleaned_dictionary:
|
||||
request_payload["dictionary"] = cleaned_dictionary
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
return _extract_chat_text(response)
|
||||
"temperature": 0.0,
|
||||
}
|
||||
if _supports_response_format(self.client.create_chat_completion):
|
||||
kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
response = self.client.create_chat_completion(**kwargs)
|
||||
return _extract_cleaned_text(response)
|
||||
|
||||
|
||||
def ensure_model():
|
||||
|
|
@ -135,6 +146,32 @@ def _extract_chat_text(payload: Any) -> str:
|
|||
raise RuntimeError("unexpected response format")
|
||||
|
||||
|
||||
def _extract_cleaned_text(payload: Any) -> str:
|
||||
raw = _extract_chat_text(payload)
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError("unexpected ai output format: expected JSON") from exc
|
||||
|
||||
if isinstance(parsed, str):
|
||||
return parsed
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
cleaned_text = parsed.get("cleaned_text")
|
||||
if isinstance(cleaned_text, str):
|
||||
return cleaned_text
|
||||
|
||||
raise RuntimeError("unexpected ai output format: missing cleaned_text")
|
||||
|
||||
|
||||
def _supports_response_format(chat_completion: Callable[..., Any]) -> bool:
|
||||
try:
|
||||
signature = inspect.signature(chat_completion)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return "response_format" in signature.parameters
|
||||
|
||||
|
||||
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
||||
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue