Remove log_transcript config and enforce JSON AI output
This commit is contained in:
parent
c3503fbbde
commit
1423e44008
8 changed files with 198 additions and 62 deletions
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -12,7 +14,7 @@ from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
|
|||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are an amanuensis working for an user.\n"
|
||||
"You'll receive an accurate transcript of what the user said.\n"
|
||||
"You'll receive a JSON object with the transcript and optional context.\n"
|
||||
"Your job is to rewrite the user's transcript into clean prose.\n"
|
||||
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
|
||||
|
||||
|
|
@ -24,14 +26,15 @@ SYSTEM_PROMPT = (
|
|||
"- Remove filler words (um/uh/like)\n"
|
||||
"- Remove false starts\n"
|
||||
"- Remove self-corrections.\n"
|
||||
"- If a <dictionary> section exists, apply only the listed corrections.\n"
|
||||
"- If a dictionary section exists, apply only the listed corrections.\n"
|
||||
"- Keep dictionary spellings exactly as provided.\n"
|
||||
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
|
||||
"- Output ONLY the cleaned text, no commentary.\n\n"
|
||||
"- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
|
||||
"- Do not wrap with markdown, tags, or extra keys.\n\n"
|
||||
"Examples:\n"
|
||||
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
|
||||
" - \"Good morning Martha, nice to meet you!\" -> \"Good morning Martha, nice to meet you!\"\n"
|
||||
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
|
||||
" - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
|
||||
" - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
|
||||
" - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -61,22 +64,30 @@ class LlamaProcessor:
|
|||
domain_name: str = "general",
|
||||
domain_confidence: float = 0.0,
|
||||
) -> str:
|
||||
blocks = [
|
||||
f"<language>{lang}</language>",
|
||||
f'<domain name="{domain_name}" confidence="{domain_confidence:.2f}"/>',
|
||||
]
|
||||
if dictionary_context.strip():
|
||||
blocks.append(f"<dictionary>\n{dictionary_context.strip()}\n</dictionary>")
|
||||
blocks.append(f"<transcript>{text}</transcript>")
|
||||
user_content = "\n".join(blocks)
|
||||
response = self.client.create_chat_completion(
|
||||
messages=[
|
||||
request_payload: dict[str, Any] = {
|
||||
"language": lang,
|
||||
"domain": {
|
||||
"name": domain_name,
|
||||
"confidence": round(float(domain_confidence), 2),
|
||||
},
|
||||
"transcript": text,
|
||||
}
|
||||
cleaned_dictionary = dictionary_context.strip()
|
||||
if cleaned_dictionary:
|
||||
request_payload["dictionary"] = cleaned_dictionary
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
return _extract_chat_text(response)
|
||||
"temperature": 0.0,
|
||||
}
|
||||
if _supports_response_format(self.client.create_chat_completion):
|
||||
kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
response = self.client.create_chat_completion(**kwargs)
|
||||
return _extract_cleaned_text(response)
|
||||
|
||||
|
||||
def ensure_model():
|
||||
|
|
@ -135,6 +146,32 @@ def _extract_chat_text(payload: Any) -> str:
|
|||
raise RuntimeError("unexpected response format")
|
||||
|
||||
|
||||
def _extract_cleaned_text(payload: Any) -> str:
    """Decode the chat response body as JSON and pull out the cleaned text.

    Accepts either a bare JSON string or an object of the form
    ``{"cleaned_text": "..."}``; any other shape is treated as a
    malformed model reply.

    Raises:
        RuntimeError: if the body is not valid JSON, or if the decoded
            value carries no string ``cleaned_text``.
    """
    body = _extract_chat_text(payload)
    try:
        decoded = json.loads(body)
    except json.JSONDecodeError as exc:
        raise RuntimeError("unexpected ai output format: expected JSON") from exc

    # Some models return the cleaned text as a top-level JSON string
    # rather than wrapping it in an object; accept both.
    if isinstance(decoded, str):
        return decoded

    cleaned = decoded.get("cleaned_text") if isinstance(decoded, dict) else None
    if isinstance(cleaned, str):
        return cleaned

    raise RuntimeError("unexpected ai output format: missing cleaned_text")
|
||||
|
||||
|
||||
def _supports_response_format(chat_completion: Callable[..., Any]) -> bool:
|
||||
try:
|
||||
signature = inspect.signature(chat_completion)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return "response_format" in signature.parameters
|
||||
|
||||
|
||||
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
||||
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,11 +45,6 @@ class AiConfig:
|
|||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
log_transcript: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyReplacement:
|
||||
source: str
|
||||
|
|
@ -77,7 +72,6 @@ class Config:
|
|||
stt: SttConfig = field(default_factory=SttConfig)
|
||||
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
||||
ai: AiConfig = field(default_factory=AiConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
|
||||
|
||||
|
|
@ -125,9 +119,6 @@ def validate(cfg: Config) -> None:
|
|||
if not isinstance(cfg.ai.enabled, bool):
|
||||
raise ValueError("ai.enabled must be boolean")
|
||||
|
||||
if not isinstance(cfg.logging.log_transcript, bool):
|
||||
raise ValueError("logging.log_transcript must be boolean")
|
||||
|
||||
cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules")
|
||||
cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms")
|
||||
|
||||
|
|
@ -153,6 +144,11 @@ def validate(cfg: Config) -> None:
|
|||
|
||||
|
||||
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||
if "logging" in data:
|
||||
raise ValueError("logging section is no longer supported; use -v/--verbose")
|
||||
if "log_transcript" in data:
|
||||
raise ValueError("log_transcript is no longer supported; use -v/--verbose")
|
||||
|
||||
has_sections = any(
|
||||
key in data
|
||||
for key in (
|
||||
|
|
@ -161,7 +157,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
"stt",
|
||||
"injection",
|
||||
"ai",
|
||||
"logging",
|
||||
"vocabulary",
|
||||
"domain_inference",
|
||||
)
|
||||
|
|
@ -172,7 +167,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
stt = _ensure_dict(data.get("stt"), "stt")
|
||||
injection = _ensure_dict(data.get("injection"), "injection")
|
||||
ai = _ensure_dict(data.get("ai"), "ai")
|
||||
logging_cfg = _ensure_dict(data.get("logging"), "logging")
|
||||
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
||||
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
|
||||
|
||||
|
|
@ -188,8 +182,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
|
||||
if "enabled" in ai:
|
||||
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
|
||||
if "log_transcript" in logging_cfg:
|
||||
cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript")
|
||||
if "replacements" in vocabulary:
|
||||
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
||||
if "terms" in vocabulary:
|
||||
|
|
@ -220,8 +212,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
|||
cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend")
|
||||
if "ai_enabled" in data:
|
||||
cfg.ai.enabled = _as_bool(data["ai_enabled"], "ai_enabled")
|
||||
if "log_transcript" in data:
|
||||
cfg.logging.log_transcript = _as_bool(data["log_transcript"], "log_transcript")
|
||||
return cfg
|
||||
|
||||
|
||||
|
|
|
|||
18
src/leld.py
18
src/leld.py
|
|
@ -71,7 +71,7 @@ class Daemon:
|
|||
cfg.stt.device,
|
||||
)
|
||||
self.ai_processor: LlamaProcessor | None = None
|
||||
self.log_transcript = cfg.logging.log_transcript or verbose
|
||||
self.log_transcript = verbose
|
||||
self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
|
||||
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ class Daemon:
|
|||
prev = self.state
|
||||
self.state = state
|
||||
if prev != state:
|
||||
logging.info("state: %s -> %s", prev, state)
|
||||
logging.debug("state: %s -> %s", prev, state)
|
||||
|
||||
def get_state(self):
|
||||
with self.lock:
|
||||
|
|
@ -118,7 +118,7 @@ class Daemon:
|
|||
self.record = record
|
||||
prev = self.state
|
||||
self.state = State.RECORDING
|
||||
logging.info("state: %s -> %s", prev, self.state)
|
||||
logging.debug("state: %s -> %s", prev, self.state)
|
||||
logging.info("recording started")
|
||||
if self.timer:
|
||||
self.timer.cancel()
|
||||
|
|
@ -145,10 +145,10 @@ class Daemon:
|
|||
self.record = None
|
||||
if self.timer:
|
||||
self.timer.cancel()
|
||||
self.timer = None
|
||||
self.timer = None
|
||||
prev = self.state
|
||||
self.state = State.STT
|
||||
logging.info("state: %s -> %s", prev, self.state)
|
||||
logging.debug("state: %s -> %s", prev, self.state)
|
||||
|
||||
if stream is None or record is None:
|
||||
logging.warning("recording resources are unavailable during stop")
|
||||
|
|
@ -189,7 +189,7 @@ class Daemon:
|
|||
return
|
||||
|
||||
if self.log_transcript:
|
||||
logging.info("stt: %s", text)
|
||||
logging.debug("stt: %s", text)
|
||||
else:
|
||||
logging.info("stt produced %d chars", len(text))
|
||||
|
||||
|
|
@ -214,7 +214,7 @@ class Daemon:
|
|||
text = self.vocabulary.apply_deterministic_replacements(text).strip()
|
||||
|
||||
if self.log_transcript:
|
||||
logging.info("processed: %s", text)
|
||||
logging.debug("processed: %s", text)
|
||||
else:
|
||||
logging.info("processed text length: %d", len(text))
|
||||
|
||||
|
|
@ -353,7 +353,7 @@ def main():
|
|||
|
||||
logging.basicConfig(
|
||||
stream=sys.stderr,
|
||||
level=logging.INFO,
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
format="lel: %(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
cfg = load(args.config)
|
||||
|
|
@ -366,8 +366,6 @@ def main():
|
|||
json.dumps(redacted_dict(cfg), indent=2),
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
try:
|
||||
desktop = get_desktop_adapter()
|
||||
daemon = Daemon(cfg, desktop, verbose=args.verbose)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue