Remove log_transcript config and enforce JSON AI output
This commit is contained in:
parent
c3503fbbde
commit
1423e44008
8 changed files with 198 additions and 62 deletions
|
|
@ -92,7 +92,6 @@ Create `~/.config/lel/config.json`:
|
||||||
"stt": { "model": "base", "device": "cpu" },
|
"stt": { "model": "base", "device": "cpu" },
|
||||||
"injection": { "backend": "clipboard" },
|
"injection": { "backend": "clipboard" },
|
||||||
"ai": { "enabled": true },
|
"ai": { "enabled": true },
|
||||||
"logging": { "log_transcript": false },
|
|
||||||
"vocabulary": {
|
"vocabulary": {
|
||||||
"replacements": [
|
"replacements": [
|
||||||
{ "from": "Martha", "to": "Marta" },
|
{ "from": "Martha", "to": "Marta" },
|
||||||
|
|
@ -113,9 +112,9 @@ name.
|
||||||
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
||||||
downloaded to `~/.cache/lel/models/` on first use.
|
downloaded to `~/.cache/lel/models/` on first use.
|
||||||
|
|
||||||
`logging.log_transcript` controls whether recognized/processed text is written
|
Use `-v/--verbose` to enable DEBUG logs, including recognized/processed
|
||||||
to logs. This is disabled by default. `-v/--verbose` also enables transcript
|
transcript text and llama.cpp logs (`llama::` prefix). Without `-v`, logs are
|
||||||
logging and llama.cpp logs; llama logs are prefixed with `llama::`.
|
INFO level.
|
||||||
|
|
||||||
Vocabulary correction:
|
Vocabulary correction:
|
||||||
|
|
||||||
|
|
@ -152,7 +151,7 @@ systemctl --user enable --now lel
|
||||||
- Press the hotkey once to start recording.
|
- Press the hotkey once to start recording.
|
||||||
- Press it again to stop and run STT.
|
- Press it again to stop and run STT.
|
||||||
- Press `Esc` while recording to cancel without processing.
|
- Press `Esc` while recording to cancel without processing.
|
||||||
- Transcript contents are logged only when `logging.log_transcript` is enabled or `-v/--verbose` is used.
|
- Transcript contents are logged only when `-v/--verbose` is used.
|
||||||
|
|
||||||
Wayland note:
|
Wayland note:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,6 @@
|
||||||
"ai": {
|
"ai": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
"logging": {
|
|
||||||
"log_transcript": true
|
|
||||||
},
|
|
||||||
"vocabulary": {
|
"vocabulary": {
|
||||||
"replacements": [
|
"replacements": [
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -12,7 +14,7 @@ from constants import MODEL_DIR, MODEL_NAME, MODEL_PATH, MODEL_URL
|
||||||
|
|
||||||
SYSTEM_PROMPT = (
|
SYSTEM_PROMPT = (
|
||||||
"You are an amanuensis working for an user.\n"
|
"You are an amanuensis working for an user.\n"
|
||||||
"You'll receive an accurate transcript of what the user said.\n"
|
"You'll receive a JSON object with the transcript and optional context.\n"
|
||||||
"Your job is to rewrite the user's transcript into clean prose.\n"
|
"Your job is to rewrite the user's transcript into clean prose.\n"
|
||||||
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
|
"Your output will be directly pasted in the currently focused application on the user computer.\n\n"
|
||||||
|
|
||||||
|
|
@ -24,14 +26,15 @@ SYSTEM_PROMPT = (
|
||||||
"- Remove filler words (um/uh/like)\n"
|
"- Remove filler words (um/uh/like)\n"
|
||||||
"- Remove false starts\n"
|
"- Remove false starts\n"
|
||||||
"- Remove self-corrections.\n"
|
"- Remove self-corrections.\n"
|
||||||
"- If a <dictionary> section exists, apply only the listed corrections.\n"
|
"- If a dictionary section exists, apply only the listed corrections.\n"
|
||||||
"- Keep dictionary spellings exactly as provided.\n"
|
"- Keep dictionary spellings exactly as provided.\n"
|
||||||
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
|
"- Treat domain hints as advisory only; never invent context-specific jargon.\n"
|
||||||
"- Output ONLY the cleaned text, no commentary.\n\n"
|
"- Return ONLY valid JSON in this shape: {\"cleaned_text\": \"...\"}\n"
|
||||||
|
"- Do not wrap with markdown, tags, or extra keys.\n\n"
|
||||||
"Examples:\n"
|
"Examples:\n"
|
||||||
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
|
" - transcript=\"Hey, schedule that for 5 PM, I mean 4 PM\" -> {\"cleaned_text\":\"Hey, schedule that for 4 PM\"}\n"
|
||||||
" - \"Good morning Martha, nice to meet you!\" -> \"Good morning Martha, nice to meet you!\"\n"
|
" - transcript=\"Good morning Martha, nice to meet you!\" -> {\"cleaned_text\":\"Good morning Martha, nice to meet you!\"}\n"
|
||||||
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
|
" - transcript=\"let's ask Bob, I mean Janice, let's ask Janice\" -> {\"cleaned_text\":\"let's ask Janice\"}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,22 +64,30 @@ class LlamaProcessor:
|
||||||
domain_name: str = "general",
|
domain_name: str = "general",
|
||||||
domain_confidence: float = 0.0,
|
domain_confidence: float = 0.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
blocks = [
|
request_payload: dict[str, Any] = {
|
||||||
f"<language>{lang}</language>",
|
"language": lang,
|
||||||
f'<domain name="{domain_name}" confidence="{domain_confidence:.2f}"/>',
|
"domain": {
|
||||||
]
|
"name": domain_name,
|
||||||
if dictionary_context.strip():
|
"confidence": round(float(domain_confidence), 2),
|
||||||
blocks.append(f"<dictionary>\n{dictionary_context.strip()}\n</dictionary>")
|
},
|
||||||
blocks.append(f"<transcript>{text}</transcript>")
|
"transcript": text,
|
||||||
user_content = "\n".join(blocks)
|
}
|
||||||
response = self.client.create_chat_completion(
|
cleaned_dictionary = dictionary_context.strip()
|
||||||
messages=[
|
if cleaned_dictionary:
|
||||||
|
request_payload["dictionary"] = cleaned_dictionary
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"messages": [
|
||||||
{"role": "system", "content": SYSTEM_PROMPT},
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
{"role": "user", "content": user_content},
|
{"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
|
||||||
],
|
],
|
||||||
temperature=0.0,
|
"temperature": 0.0,
|
||||||
)
|
}
|
||||||
return _extract_chat_text(response)
|
if _supports_response_format(self.client.create_chat_completion):
|
||||||
|
kwargs["response_format"] = {"type": "json_object"}
|
||||||
|
|
||||||
|
response = self.client.create_chat_completion(**kwargs)
|
||||||
|
return _extract_cleaned_text(response)
|
||||||
|
|
||||||
|
|
||||||
def ensure_model():
|
def ensure_model():
|
||||||
|
|
@ -135,6 +146,32 @@ def _extract_chat_text(payload: Any) -> str:
|
||||||
raise RuntimeError("unexpected response format")
|
raise RuntimeError("unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cleaned_text(payload: Any) -> str:
|
||||||
|
raw = _extract_chat_text(payload)
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RuntimeError("unexpected ai output format: expected JSON") from exc
|
||||||
|
|
||||||
|
if isinstance(parsed, str):
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
cleaned_text = parsed.get("cleaned_text")
|
||||||
|
if isinstance(cleaned_text, str):
|
||||||
|
return cleaned_text
|
||||||
|
|
||||||
|
raise RuntimeError("unexpected ai output format: missing cleaned_text")
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_response_format(chat_completion: Callable[..., Any]) -> bool:
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(chat_completion)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
return "response_format" in signature.parameters
|
||||||
|
|
||||||
|
|
||||||
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
||||||
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,11 +45,6 @@ class AiConfig:
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoggingConfig:
|
|
||||||
log_transcript: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VocabularyReplacement:
|
class VocabularyReplacement:
|
||||||
source: str
|
source: str
|
||||||
|
|
@ -77,7 +72,6 @@ class Config:
|
||||||
stt: SttConfig = field(default_factory=SttConfig)
|
stt: SttConfig = field(default_factory=SttConfig)
|
||||||
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
||||||
ai: AiConfig = field(default_factory=AiConfig)
|
ai: AiConfig = field(default_factory=AiConfig)
|
||||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
|
||||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||||
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
|
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
|
||||||
|
|
||||||
|
|
@ -125,9 +119,6 @@ def validate(cfg: Config) -> None:
|
||||||
if not isinstance(cfg.ai.enabled, bool):
|
if not isinstance(cfg.ai.enabled, bool):
|
||||||
raise ValueError("ai.enabled must be boolean")
|
raise ValueError("ai.enabled must be boolean")
|
||||||
|
|
||||||
if not isinstance(cfg.logging.log_transcript, bool):
|
|
||||||
raise ValueError("logging.log_transcript must be boolean")
|
|
||||||
|
|
||||||
cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules")
|
cfg.vocabulary.max_rules = _validated_limit(cfg.vocabulary.max_rules, "vocabulary.max_rules")
|
||||||
cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms")
|
cfg.vocabulary.max_terms = _validated_limit(cfg.vocabulary.max_terms, "vocabulary.max_terms")
|
||||||
|
|
||||||
|
|
@ -153,6 +144,11 @@ def validate(cfg: Config) -> None:
|
||||||
|
|
||||||
|
|
||||||
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
|
if "logging" in data:
|
||||||
|
raise ValueError("logging section is no longer supported; use -v/--verbose")
|
||||||
|
if "log_transcript" in data:
|
||||||
|
raise ValueError("log_transcript is no longer supported; use -v/--verbose")
|
||||||
|
|
||||||
has_sections = any(
|
has_sections = any(
|
||||||
key in data
|
key in data
|
||||||
for key in (
|
for key in (
|
||||||
|
|
@ -161,7 +157,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
"stt",
|
"stt",
|
||||||
"injection",
|
"injection",
|
||||||
"ai",
|
"ai",
|
||||||
"logging",
|
|
||||||
"vocabulary",
|
"vocabulary",
|
||||||
"domain_inference",
|
"domain_inference",
|
||||||
)
|
)
|
||||||
|
|
@ -172,7 +167,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
stt = _ensure_dict(data.get("stt"), "stt")
|
stt = _ensure_dict(data.get("stt"), "stt")
|
||||||
injection = _ensure_dict(data.get("injection"), "injection")
|
injection = _ensure_dict(data.get("injection"), "injection")
|
||||||
ai = _ensure_dict(data.get("ai"), "ai")
|
ai = _ensure_dict(data.get("ai"), "ai")
|
||||||
logging_cfg = _ensure_dict(data.get("logging"), "logging")
|
|
||||||
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
||||||
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
|
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
|
||||||
|
|
||||||
|
|
@ -188,8 +182,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
|
cfg.injection.backend = _as_nonempty_str(injection["backend"], "injection.backend")
|
||||||
if "enabled" in ai:
|
if "enabled" in ai:
|
||||||
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
|
cfg.ai.enabled = _as_bool(ai["enabled"], "ai.enabled")
|
||||||
if "log_transcript" in logging_cfg:
|
|
||||||
cfg.logging.log_transcript = _as_bool(logging_cfg["log_transcript"], "logging.log_transcript")
|
|
||||||
if "replacements" in vocabulary:
|
if "replacements" in vocabulary:
|
||||||
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
||||||
if "terms" in vocabulary:
|
if "terms" in vocabulary:
|
||||||
|
|
@ -220,8 +212,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend")
|
cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend")
|
||||||
if "ai_enabled" in data:
|
if "ai_enabled" in data:
|
||||||
cfg.ai.enabled = _as_bool(data["ai_enabled"], "ai_enabled")
|
cfg.ai.enabled = _as_bool(data["ai_enabled"], "ai_enabled")
|
||||||
if "log_transcript" in data:
|
|
||||||
cfg.logging.log_transcript = _as_bool(data["log_transcript"], "log_transcript")
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
18
src/leld.py
18
src/leld.py
|
|
@ -71,7 +71,7 @@ class Daemon:
|
||||||
cfg.stt.device,
|
cfg.stt.device,
|
||||||
)
|
)
|
||||||
self.ai_processor: LlamaProcessor | None = None
|
self.ai_processor: LlamaProcessor | None = None
|
||||||
self.log_transcript = cfg.logging.log_transcript or verbose
|
self.log_transcript = verbose
|
||||||
self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
|
self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
|
||||||
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ class Daemon:
|
||||||
prev = self.state
|
prev = self.state
|
||||||
self.state = state
|
self.state = state
|
||||||
if prev != state:
|
if prev != state:
|
||||||
logging.info("state: %s -> %s", prev, state)
|
logging.debug("state: %s -> %s", prev, state)
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
|
@ -118,7 +118,7 @@ class Daemon:
|
||||||
self.record = record
|
self.record = record
|
||||||
prev = self.state
|
prev = self.state
|
||||||
self.state = State.RECORDING
|
self.state = State.RECORDING
|
||||||
logging.info("state: %s -> %s", prev, self.state)
|
logging.debug("state: %s -> %s", prev, self.state)
|
||||||
logging.info("recording started")
|
logging.info("recording started")
|
||||||
if self.timer:
|
if self.timer:
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
|
|
@ -145,10 +145,10 @@ class Daemon:
|
||||||
self.record = None
|
self.record = None
|
||||||
if self.timer:
|
if self.timer:
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
self.timer = None
|
self.timer = None
|
||||||
prev = self.state
|
prev = self.state
|
||||||
self.state = State.STT
|
self.state = State.STT
|
||||||
logging.info("state: %s -> %s", prev, self.state)
|
logging.debug("state: %s -> %s", prev, self.state)
|
||||||
|
|
||||||
if stream is None or record is None:
|
if stream is None or record is None:
|
||||||
logging.warning("recording resources are unavailable during stop")
|
logging.warning("recording resources are unavailable during stop")
|
||||||
|
|
@ -189,7 +189,7 @@ class Daemon:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.log_transcript:
|
if self.log_transcript:
|
||||||
logging.info("stt: %s", text)
|
logging.debug("stt: %s", text)
|
||||||
else:
|
else:
|
||||||
logging.info("stt produced %d chars", len(text))
|
logging.info("stt produced %d chars", len(text))
|
||||||
|
|
||||||
|
|
@ -214,7 +214,7 @@ class Daemon:
|
||||||
text = self.vocabulary.apply_deterministic_replacements(text).strip()
|
text = self.vocabulary.apply_deterministic_replacements(text).strip()
|
||||||
|
|
||||||
if self.log_transcript:
|
if self.log_transcript:
|
||||||
logging.info("processed: %s", text)
|
logging.debug("processed: %s", text)
|
||||||
else:
|
else:
|
||||||
logging.info("processed text length: %d", len(text))
|
logging.info("processed text length: %d", len(text))
|
||||||
|
|
||||||
|
|
@ -353,7 +353,7 @@ def main():
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
stream=sys.stderr,
|
stream=sys.stderr,
|
||||||
level=logging.INFO,
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
format="lel: %(asctime)s %(levelname)s %(message)s",
|
format="lel: %(asctime)s %(levelname)s %(message)s",
|
||||||
)
|
)
|
||||||
cfg = load(args.config)
|
cfg = load(args.config)
|
||||||
|
|
@ -366,8 +366,6 @@ def main():
|
||||||
json.dumps(redacted_dict(cfg), indent=2),
|
json.dumps(redacted_dict(cfg), indent=2),
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.verbose:
|
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
|
||||||
try:
|
try:
|
||||||
desktop = get_desktop_adapter()
|
desktop = get_desktop_adapter()
|
||||||
daemon = Daemon(cfg, desktop, verbose=args.verbose)
|
daemon = Daemon(cfg, desktop, verbose=args.verbose)
|
||||||
|
|
|
||||||
88
tests/test_aiprocess.py
Normal file
88
tests/test_aiprocess.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
if str(SRC) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
|
from aiprocess import _extract_cleaned_text, _supports_response_format
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractCleanedTextTests(unittest.TestCase):
|
||||||
|
def test_extracts_cleaned_text_from_json_object(self):
|
||||||
|
payload = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '{"cleaned_text":"Hello <transcript>literal</transcript> world"}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = _extract_cleaned_text(payload)
|
||||||
|
|
||||||
|
self.assertEqual(result, "Hello <transcript>literal</transcript> world")
|
||||||
|
|
||||||
|
def test_extracts_cleaned_text_from_json_string(self):
|
||||||
|
payload = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '"He said \\\"hello\\\""'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = _extract_cleaned_text(payload)
|
||||||
|
|
||||||
|
self.assertEqual(result, 'He said "hello"')
|
||||||
|
|
||||||
|
def test_rejects_non_json_output(self):
|
||||||
|
payload = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": "<transcript>Hello</transcript>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expected JSON"):
|
||||||
|
_extract_cleaned_text(payload)
|
||||||
|
|
||||||
|
def test_rejects_json_without_required_key(self):
|
||||||
|
payload = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '{"text":"hello"}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "missing cleaned_text"):
|
||||||
|
_extract_cleaned_text(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class SupportsResponseFormatTests(unittest.TestCase):
|
||||||
|
def test_supports_response_format_when_parameter_exists(self):
|
||||||
|
def chat_completion(*, messages, temperature, response_format):
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.assertTrue(_supports_response_format(chat_completion))
|
||||||
|
|
||||||
|
def test_does_not_support_response_format_when_missing(self):
|
||||||
|
def chat_completion(*, messages, temperature):
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.assertFalse(_supports_response_format(chat_completion))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -26,7 +26,6 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.stt.device, "cpu")
|
self.assertEqual(cfg.stt.device, "cpu")
|
||||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||||
self.assertTrue(cfg.ai.enabled)
|
self.assertTrue(cfg.ai.enabled)
|
||||||
self.assertFalse(cfg.logging.log_transcript)
|
|
||||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
self.assertEqual(cfg.vocabulary.terms, [])
|
self.assertEqual(cfg.vocabulary.terms, [])
|
||||||
self.assertEqual(cfg.vocabulary.max_rules, 500)
|
self.assertEqual(cfg.vocabulary.max_rules, 500)
|
||||||
|
|
@ -41,7 +40,6 @@ class ConfigTests(unittest.TestCase):
|
||||||
"stt": {"model": "small", "device": "cuda"},
|
"stt": {"model": "small", "device": "cuda"},
|
||||||
"injection": {"backend": "injection"},
|
"injection": {"backend": "injection"},
|
||||||
"ai": {"enabled": False},
|
"ai": {"enabled": False},
|
||||||
"logging": {"log_transcript": True},
|
|
||||||
"vocabulary": {
|
"vocabulary": {
|
||||||
"replacements": [
|
"replacements": [
|
||||||
{"from": "Martha", "to": "Marta"},
|
{"from": "Martha", "to": "Marta"},
|
||||||
|
|
@ -65,7 +63,6 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.stt.device, "cuda")
|
self.assertEqual(cfg.stt.device, "cuda")
|
||||||
self.assertEqual(cfg.injection.backend, "injection")
|
self.assertEqual(cfg.injection.backend, "injection")
|
||||||
self.assertFalse(cfg.ai.enabled)
|
self.assertFalse(cfg.ai.enabled)
|
||||||
self.assertTrue(cfg.logging.log_transcript)
|
|
||||||
self.assertEqual(cfg.vocabulary.max_rules, 100)
|
self.assertEqual(cfg.vocabulary.max_rules, 100)
|
||||||
self.assertEqual(cfg.vocabulary.max_terms, 200)
|
self.assertEqual(cfg.vocabulary.max_terms, 200)
|
||||||
self.assertEqual(len(cfg.vocabulary.replacements), 2)
|
self.assertEqual(len(cfg.vocabulary.replacements), 2)
|
||||||
|
|
@ -83,7 +80,6 @@ class ConfigTests(unittest.TestCase):
|
||||||
"whisper_device": "cpu",
|
"whisper_device": "cpu",
|
||||||
"injection_backend": "clipboard",
|
"injection_backend": "clipboard",
|
||||||
"ai_enabled": False,
|
"ai_enabled": False,
|
||||||
"log_transcript": True,
|
|
||||||
}
|
}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
|
|
@ -97,7 +93,6 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.stt.device, "cpu")
|
self.assertEqual(cfg.stt.device, "cpu")
|
||||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||||
self.assertFalse(cfg.ai.enabled)
|
self.assertFalse(cfg.ai.enabled)
|
||||||
self.assertTrue(cfg.logging.log_transcript)
|
|
||||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
|
|
||||||
def test_invalid_injection_backend_raises(self):
|
def test_invalid_injection_backend_raises(self):
|
||||||
|
|
@ -109,13 +104,22 @@ class ConfigTests(unittest.TestCase):
|
||||||
with self.assertRaisesRegex(ValueError, "injection.backend"):
|
with self.assertRaisesRegex(ValueError, "injection.backend"):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
def test_invalid_logging_flag_raises(self):
|
def test_removed_logging_section_raises(self):
|
||||||
payload = {"logging": {"log_transcript": "yes"}}
|
payload = {"logging": {"log_transcript": True}}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "logging.log_transcript"):
|
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
|
def test_removed_legacy_log_transcript_raises(self):
|
||||||
|
payload = {"log_transcript": True}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
def test_conflicting_replacements_raise(self):
|
def test_conflicting_replacements_raise(self):
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,6 @@ class DaemonTests(unittest.TestCase):
|
||||||
def _config(self) -> Config:
|
def _config(self) -> Config:
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
cfg.ai.enabled = False
|
cfg.ai.enabled = False
|
||||||
cfg.logging.log_transcript = False
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
@patch("leld.stop_audio_recording", return_value=FakeAudio(8))
|
||||||
|
|
@ -178,6 +177,30 @@ class DaemonTests(unittest.TestCase):
|
||||||
self.assertIn("Systemd", model.last_kwargs["hotwords"])
|
self.assertIn("Systemd", model.last_kwargs["hotwords"])
|
||||||
self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
|
self.assertIn("Preferred vocabulary", model.last_kwargs["initial_prompt"])
|
||||||
|
|
||||||
|
def test_verbose_flag_controls_transcript_logging(self):
|
||||||
|
desktop = FakeDesktop()
|
||||||
|
cfg = self._config()
|
||||||
|
|
||||||
|
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||||
|
daemon = leld.Daemon(cfg, desktop, verbose=False)
|
||||||
|
self.assertFalse(daemon.log_transcript)
|
||||||
|
|
||||||
|
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||||
|
daemon_verbose = leld.Daemon(cfg, desktop, verbose=True)
|
||||||
|
self.assertTrue(daemon_verbose.log_transcript)
|
||||||
|
|
||||||
|
def test_state_changes_are_debug_level(self):
|
||||||
|
desktop = FakeDesktop()
|
||||||
|
with patch("leld._build_whisper_model", return_value=FakeModel()):
|
||||||
|
daemon = leld.Daemon(self._config(), desktop, verbose=False)
|
||||||
|
|
||||||
|
with self.assertLogs(level="DEBUG") as logs:
|
||||||
|
daemon.set_state(leld.State.RECORDING)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
any("DEBUG:root:state: idle -> recording" in line for line in logs.output)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LockTests(unittest.TestCase):
|
class LockTests(unittest.TestCase):
|
||||||
def test_lock_rejects_second_instance(self):
|
def test_lock_rejects_second_instance(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue