Revert "Add pipeline engine and remove legacy compatibility paths"
This commit is contained in:
parent
e221d49020
commit
5b38cc7dcd
18 changed files with 399 additions and 1523 deletions
88
README.md
88
README.md
|
|
@ -80,7 +80,8 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi
|
||||||
{ "from": "docker", "to": "Docker" }
|
{ "from": "docker", "to": "Docker" }
|
||||||
],
|
],
|
||||||
"terms": ["Systemd", "Kubernetes"]
|
"terms": ["Systemd", "Kubernetes"]
|
||||||
}
|
},
|
||||||
|
"domain_inference": { "enabled": true }
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -91,9 +92,6 @@ Hotkey notes:
|
||||||
|
|
||||||
- Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`).
|
- Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`).
|
||||||
- `Super` and `Cmd` are equivalent aliases for the same modifier.
|
- `Super` and `Cmd` are equivalent aliases for the same modifier.
|
||||||
- Invalid hotkey syntax in config prevents startup/reload.
|
|
||||||
- When `~/.config/aman/pipelines.py` exists, hotkeys come from `HOTKEY_PIPELINES`.
|
|
||||||
- `daemon.hotkey` is used as the fallback/default hotkey only when no pipelines file is present.
|
|
||||||
|
|
||||||
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
|
||||||
downloaded to `~/.cache/aman/models/` during daemon initialization.
|
downloaded to `~/.cache/aman/models/` during daemon initialization.
|
||||||
|
|
@ -109,6 +107,11 @@ Vocabulary correction:
|
||||||
- Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules.
|
- Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules.
|
||||||
- Rules are deduplicated case-insensitively; conflicting replacements are rejected.
|
- Rules are deduplicated case-insensitively; conflicting replacements are rejected.
|
||||||
|
|
||||||
|
Domain inference:
|
||||||
|
|
||||||
|
- Domain context is advisory only and is used to improve cleanup prompts.
|
||||||
|
- When confidence is low, it falls back to `general` context.
|
||||||
|
|
||||||
STT hinting:
|
STT hinting:
|
||||||
|
|
||||||
- Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those
|
- Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those
|
||||||
|
|
@ -131,12 +134,6 @@ systemctl --user enable --now aman
|
||||||
- Press it again to stop and run STT.
|
- Press it again to stop and run STT.
|
||||||
- Press `Esc` while recording to cancel without processing.
|
- Press `Esc` while recording to cancel without processing.
|
||||||
- Transcript contents are logged only when `-v/--verbose` is used.
|
- Transcript contents are logged only when `-v/--verbose` is used.
|
||||||
- Config changes are hot-reloaded automatically (polled every 1 second).
|
|
||||||
- `~/.config/aman/pipelines.py` changes are hot-reloaded automatically (polled every 1 second).
|
|
||||||
- Send `SIGHUP` to force an immediate reload of config and pipelines:
|
|
||||||
`systemctl --user kill -s HUP aman` (or send `HUP` to the process directly).
|
|
||||||
- Reloads are applied when the daemon is idle; invalid updates are rejected and the last valid config stays active.
|
|
||||||
- Reload success/failure is logged, and desktop notifications are shown when available.
|
|
||||||
|
|
||||||
Wayland note:
|
Wayland note:
|
||||||
|
|
||||||
|
|
@ -152,77 +149,6 @@ AI processing:
|
||||||
|
|
||||||
- Local llama.cpp model only (no remote provider configuration).
|
- Local llama.cpp model only (no remote provider configuration).
|
||||||
|
|
||||||
## Pipelines API
|
|
||||||
|
|
||||||
`aman` is split into:
|
|
||||||
|
|
||||||
- shell daemon: hotkeys, recording/cancel, and desktop injection
|
|
||||||
- pipeline engine: `lib.transcribe(...)` and `lib.llm(...)`
|
|
||||||
- pipeline implementation: Python callables mapped per hotkey
|
|
||||||
|
|
||||||
Pipeline file path:
|
|
||||||
|
|
||||||
- `~/.config/aman/pipelines.py`
|
|
||||||
- You can start from [`pipelines.example.py`](./pipelines.example.py).
|
|
||||||
- If `pipelines.py` is missing, `aman` uses a built-in reference pipeline bound to `daemon.hotkey`.
|
|
||||||
- If `pipelines.py` exists but is invalid, startup fails fast.
|
|
||||||
- Pipelines are hot-reloaded automatically when the module file changes.
|
|
||||||
- Send `SIGHUP` to force an immediate reload of both config and pipelines.
|
|
||||||
|
|
||||||
Expected module exports:
|
|
||||||
|
|
||||||
```python
|
|
||||||
HOTKEY_PIPELINES = {
|
|
||||||
"Super+m": my_pipeline,
|
|
||||||
"Super+Shift+m": caps_pipeline,
|
|
||||||
}
|
|
||||||
|
|
||||||
PIPELINE_OPTIONS = {
|
|
||||||
"Super+Shift+m": {"failure_policy": "strict"}, # optional
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Pipeline callable signature:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def my_pipeline(audio, lib) -> str:
|
|
||||||
text = lib.transcribe(audio)
|
|
||||||
context = lib.llm(
|
|
||||||
system_prompt="context system prompt",
|
|
||||||
user_prompt=f"Transcript: {text}",
|
|
||||||
)
|
|
||||||
out = lib.llm(
|
|
||||||
system_prompt="amanuensis prompt",
|
|
||||||
user_prompt=f"context={context}\ntext={text}",
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
```
|
|
||||||
|
|
||||||
`lib` API:
|
|
||||||
|
|
||||||
- `lib.transcribe(audio, hints=None, whisper_opts=None) -> str`
|
|
||||||
- `lib.llm(system_prompt=..., user_prompt=..., llm_opts=None) -> str`
|
|
||||||
|
|
||||||
Failure policy options:
|
|
||||||
|
|
||||||
- `best_effort` (default): pipeline errors return empty output
|
|
||||||
- `strict`: pipeline errors abort the current run
|
|
||||||
|
|
||||||
Validation:
|
|
||||||
|
|
||||||
- `HOTKEY_PIPELINES` must be a non-empty dictionary.
|
|
||||||
- Every hotkey key must be a non-empty string.
|
|
||||||
- Every pipeline value must be callable.
|
|
||||||
- `PIPELINE_OPTIONS` must be a dictionary when provided.
|
|
||||||
|
|
||||||
Reference behavior:
|
|
||||||
|
|
||||||
- The built-in fallback pipeline (used when `pipelines.py` is missing) uses `lib.llm(...)` twice:
|
|
||||||
- first to infer context
|
|
||||||
- second to run the amanuensis rewrite
|
|
||||||
- The second pass requests JSON output and expects `{"cleaned_text": "..."}`.
|
|
||||||
- Deterministic dictionary replacements are then applied as part of that reference implementation.
|
|
||||||
|
|
||||||
Control:
|
Control:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
CONTEXT_PROMPT = (
|
|
||||||
"Return a concise plain-text context hint (max 12 words) "
|
|
||||||
"for the transcript domain and style."
|
|
||||||
)
|
|
||||||
|
|
||||||
AMANUENSIS_PROMPT = (
|
|
||||||
"Rewrite the transcript into clean prose without changing intent. "
|
|
||||||
"Return only the final text."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_pipeline(audio, lib) -> str:
|
|
||||||
text = lib.transcribe(audio)
|
|
||||||
if not text:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
context = lib.llm(
|
|
||||||
system_prompt=CONTEXT_PROMPT,
|
|
||||||
user_prompt=json.dumps({"transcript": text}, ensure_ascii=False),
|
|
||||||
llm_opts={"temperature": 0.0},
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
output = lib.llm(
|
|
||||||
system_prompt=AMANUENSIS_PROMPT,
|
|
||||||
user_prompt=json.dumps(
|
|
||||||
{"transcript": text, "context": context},
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
llm_opts={"temperature": 0.0},
|
|
||||||
)
|
|
||||||
return output.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def caps_pipeline(audio, lib) -> str:
|
|
||||||
return lib.transcribe(audio).upper()
|
|
||||||
|
|
||||||
|
|
||||||
HOTKEY_PIPELINES = {
|
|
||||||
"Super+m": default_pipeline,
|
|
||||||
"Super+Shift+m": caps_pipeline,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
PIPELINE_OPTIONS = {
|
|
||||||
"Super+m": {"failure_policy": "best_effort"},
|
|
||||||
"Super+Shift+m": {"failure_policy": "strict"},
|
|
||||||
}
|
|
||||||
|
|
@ -76,41 +76,18 @@ class LlamaProcessor:
|
||||||
if cleaned_dictionary:
|
if cleaned_dictionary:
|
||||||
request_payload["dictionary"] = cleaned_dictionary
|
request_payload["dictionary"] = cleaned_dictionary
|
||||||
|
|
||||||
content = self.chat(
|
|
||||||
system_prompt=SYSTEM_PROMPT,
|
|
||||||
user_prompt=json.dumps(request_payload, ensure_ascii=False),
|
|
||||||
llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}},
|
|
||||||
)
|
|
||||||
return _extract_cleaned_text_from_raw(content)
|
|
||||||
|
|
||||||
def chat(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
system_prompt: str,
|
|
||||||
user_prompt: str,
|
|
||||||
llm_opts: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
opts = dict(llm_opts or {})
|
|
||||||
temperature = float(opts.pop("temperature", 0.0))
|
|
||||||
response_format = opts.pop("response_format", None)
|
|
||||||
if opts:
|
|
||||||
unknown = ", ".join(sorted(opts.keys()))
|
|
||||||
raise ValueError(f"unsupported llm options: {unknown}")
|
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
{"role": "user", "content": user_prompt},
|
{"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
|
||||||
],
|
],
|
||||||
"temperature": temperature,
|
"temperature": 0.0,
|
||||||
}
|
}
|
||||||
if response_format is not None and _supports_response_format(
|
if _supports_response_format(self.client.create_chat_completion):
|
||||||
self.client.create_chat_completion
|
kwargs["response_format"] = {"type": "json_object"}
|
||||||
):
|
|
||||||
kwargs["response_format"] = response_format
|
|
||||||
|
|
||||||
response = self.client.create_chat_completion(**kwargs)
|
response = self.client.create_chat_completion(**kwargs)
|
||||||
return _extract_chat_text(response)
|
return _extract_cleaned_text(response)
|
||||||
|
|
||||||
|
|
||||||
def ensure_model():
|
def ensure_model():
|
||||||
|
|
@ -171,10 +148,6 @@ def _extract_chat_text(payload: Any) -> str:
|
||||||
|
|
||||||
def _extract_cleaned_text(payload: Any) -> str:
|
def _extract_cleaned_text(payload: Any) -> str:
|
||||||
raw = _extract_chat_text(payload)
|
raw = _extract_chat_text(payload)
|
||||||
return _extract_cleaned_text_from_raw(raw)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_cleaned_text_from_raw(raw: str) -> str:
|
|
||||||
try:
|
try:
|
||||||
parsed = json.loads(raw)
|
parsed = json.loads(raw)
|
||||||
except json.JSONDecodeError as exc:
|
except json.JSONDecodeError as exc:
|
||||||
|
|
|
||||||
678
src/aman.py
678
src/aman.py
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import errno
|
import errno
|
||||||
import hashlib
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -13,21 +12,12 @@ import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
|
||||||
from aiprocess import LlamaProcessor, SYSTEM_PROMPT
|
from aiprocess import LlamaProcessor
|
||||||
from config import Config, load, redacted_dict
|
from config import Config, load, redacted_dict
|
||||||
from constants import (
|
from constants import RECORD_TIMEOUT_SEC, STT_LANGUAGE
|
||||||
CONFIG_RELOAD_POLL_INTERVAL_SEC,
|
|
||||||
DEFAULT_CONFIG_PATH,
|
|
||||||
RECORD_TIMEOUT_SEC,
|
|
||||||
STT_LANGUAGE,
|
|
||||||
)
|
|
||||||
from desktop import get_desktop_adapter
|
from desktop import get_desktop_adapter
|
||||||
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
|
|
||||||
from notify import DesktopNotifier
|
|
||||||
from pipelines_runtime import DEFAULT_PIPELINES_PATH, fingerprint as pipelines_fingerprint
|
|
||||||
from pipelines_runtime import load_bindings
|
|
||||||
from recorder import start_recording as start_audio_recording
|
from recorder import start_recording as start_audio_recording
|
||||||
from recorder import stop_recording as stop_audio_recording
|
from recorder import stop_recording as stop_audio_recording
|
||||||
from vocabulary import VocabularyEngine
|
from vocabulary import VocabularyEngine
|
||||||
|
|
@ -65,52 +55,6 @@ def _compute_type(device: str) -> str:
|
||||||
return "int8"
|
return "int8"
|
||||||
|
|
||||||
|
|
||||||
def _replacement_view(cfg: Config) -> list[tuple[str, str]]:
|
|
||||||
return [(item.source, item.target) for item in cfg.vocabulary.replacements]
|
|
||||||
|
|
||||||
|
|
||||||
def _config_fields_changed(current: Config, candidate: Config) -> list[str]:
|
|
||||||
changed: list[str] = []
|
|
||||||
if current.daemon.hotkey != candidate.daemon.hotkey:
|
|
||||||
changed.append("daemon.hotkey")
|
|
||||||
if current.recording.input != candidate.recording.input:
|
|
||||||
changed.append("recording.input")
|
|
||||||
if current.stt.model != candidate.stt.model:
|
|
||||||
changed.append("stt.model")
|
|
||||||
if current.stt.device != candidate.stt.device:
|
|
||||||
changed.append("stt.device")
|
|
||||||
if current.injection.backend != candidate.injection.backend:
|
|
||||||
changed.append("injection.backend")
|
|
||||||
if (
|
|
||||||
current.injection.remove_transcription_from_clipboard
|
|
||||||
!= candidate.injection.remove_transcription_from_clipboard
|
|
||||||
):
|
|
||||||
changed.append("injection.remove_transcription_from_clipboard")
|
|
||||||
if _replacement_view(current) != _replacement_view(candidate):
|
|
||||||
changed.append("vocabulary.replacements")
|
|
||||||
if current.vocabulary.terms != candidate.vocabulary.terms:
|
|
||||||
changed.append("vocabulary.terms")
|
|
||||||
return changed
|
|
||||||
|
|
||||||
|
|
||||||
CONTEXT_SYSTEM_PROMPT = (
|
|
||||||
"You infer concise writing context from transcript text.\n"
|
|
||||||
"Return a short plain-text hint (max 12 words) that describes the likely domain and style.\n"
|
|
||||||
"Do not add punctuation, quotes, XML tags, or markdown."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_cleaned_text_from_llm_raw(raw: str) -> str:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
if isinstance(parsed, str):
|
|
||||||
return parsed
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
cleaned_text = parsed.get("cleaned_text")
|
|
||||||
if isinstance(cleaned_text, str):
|
|
||||||
return cleaned_text
|
|
||||||
raise RuntimeError("unexpected ai output format: missing cleaned_text")
|
|
||||||
|
|
||||||
|
|
||||||
class Daemon:
|
class Daemon:
|
||||||
def __init__(self, cfg: Config, desktop, *, verbose: bool = False):
|
def __init__(self, cfg: Config, desktop, *, verbose: bool = False):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
@ -122,34 +66,17 @@ class Daemon:
|
||||||
self.stream = None
|
self.stream = None
|
||||||
self.record = None
|
self.record = None
|
||||||
self.timer: threading.Timer | None = None
|
self.timer: threading.Timer | None = None
|
||||||
self._config_apply_in_progress = False
|
self.model = _build_whisper_model(
|
||||||
self._active_hotkey: str | None = None
|
cfg.stt.model,
|
||||||
self._pending_worker_hotkey: str | None = None
|
cfg.stt.device,
|
||||||
self._pipeline_bindings: dict[str, PipelineBinding] = {}
|
)
|
||||||
|
|
||||||
self.model = _build_whisper_model(cfg.stt.model, cfg.stt.device)
|
|
||||||
logging.info("initializing ai processor")
|
logging.info("initializing ai processor")
|
||||||
self.ai_processor = LlamaProcessor(verbose=self.verbose)
|
self.ai_processor = LlamaProcessor(verbose=self.verbose)
|
||||||
logging.info("ai processor ready")
|
logging.info("ai processor ready")
|
||||||
self.log_transcript = verbose
|
self.log_transcript = verbose
|
||||||
self.vocabulary = VocabularyEngine(cfg.vocabulary)
|
self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
|
||||||
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
self._stt_hint_kwargs_cache: dict[str, Any] | None = None
|
||||||
|
|
||||||
self.lib = PipelineLib(
|
|
||||||
transcribe_fn=self._transcribe_with_options,
|
|
||||||
llm_fn=self._llm_with_options,
|
|
||||||
)
|
|
||||||
self.engine = Engine(self.lib)
|
|
||||||
self.set_pipeline_bindings(
|
|
||||||
{
|
|
||||||
cfg.daemon.hotkey: PipelineBinding(
|
|
||||||
hotkey=cfg.daemon.hotkey,
|
|
||||||
handler=self.build_reference_pipeline(),
|
|
||||||
options=PipelineOptions(failure_policy="best_effort"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_state(self, state: str):
|
def set_state(self, state: str):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
prev = self.state
|
prev = self.state
|
||||||
|
|
@ -166,74 +93,16 @@ class Daemon:
|
||||||
def request_shutdown(self):
|
def request_shutdown(self):
|
||||||
self._shutdown_requested.set()
|
self._shutdown_requested.set()
|
||||||
|
|
||||||
def current_hotkeys(self) -> list[str]:
|
def toggle(self):
|
||||||
with self.lock:
|
|
||||||
return list(self._pipeline_bindings.keys())
|
|
||||||
|
|
||||||
def set_pipeline_bindings(self, bindings: dict[str, PipelineBinding]) -> None:
|
|
||||||
if not bindings:
|
|
||||||
raise ValueError("at least one pipeline binding is required")
|
|
||||||
with self.lock:
|
|
||||||
self._pipeline_bindings = dict(bindings)
|
|
||||||
if self._active_hotkey and self._active_hotkey not in self._pipeline_bindings:
|
|
||||||
self._active_hotkey = None
|
|
||||||
|
|
||||||
def build_hotkey_callbacks(
|
|
||||||
self,
|
|
||||||
hotkeys: list[str],
|
|
||||||
*,
|
|
||||||
dry_run: bool,
|
|
||||||
) -> dict[str, Callable[[], None]]:
|
|
||||||
callbacks: dict[str, Callable[[], None]] = {}
|
|
||||||
for hotkey in hotkeys:
|
|
||||||
if dry_run:
|
|
||||||
callbacks[hotkey] = (lambda hk=hotkey: logging.info("hotkey pressed: %s (dry-run)", hk))
|
|
||||||
else:
|
|
||||||
callbacks[hotkey] = (lambda hk=hotkey: self.handle_hotkey(hk))
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def apply_pipeline_bindings(
|
|
||||||
self,
|
|
||||||
bindings: dict[str, PipelineBinding],
|
|
||||||
callbacks: dict[str, Callable[[], None]],
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
with self.lock:
|
|
||||||
if self._shutdown_requested.is_set():
|
|
||||||
return "deferred", "shutdown in progress"
|
|
||||||
if self._config_apply_in_progress:
|
|
||||||
return "deferred", "reload in progress"
|
|
||||||
if self.state != State.IDLE:
|
|
||||||
return "deferred", f"daemon is busy ({self.state})"
|
|
||||||
self._config_apply_in_progress = True
|
|
||||||
try:
|
|
||||||
self.desktop.set_hotkeys(callbacks)
|
|
||||||
with self.lock:
|
|
||||||
if self._shutdown_requested.is_set():
|
|
||||||
return "deferred", "shutdown in progress"
|
|
||||||
if self.state != State.IDLE:
|
|
||||||
return "deferred", f"daemon is busy ({self.state})"
|
|
||||||
self._pipeline_bindings = dict(bindings)
|
|
||||||
return "applied", ""
|
|
||||||
except Exception as exc:
|
|
||||||
return "error", str(exc)
|
|
||||||
finally:
|
|
||||||
with self.lock:
|
|
||||||
self._config_apply_in_progress = False
|
|
||||||
|
|
||||||
def handle_hotkey(self, hotkey: str):
|
|
||||||
should_stop = False
|
should_stop = False
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if self._shutdown_requested.is_set():
|
if self._shutdown_requested.is_set():
|
||||||
logging.info("shutdown in progress, trigger ignored")
|
logging.info("shutdown in progress, trigger ignored")
|
||||||
return
|
return
|
||||||
if self._config_apply_in_progress:
|
|
||||||
logging.info("reload in progress, trigger ignored")
|
|
||||||
return
|
|
||||||
if self.state == State.IDLE:
|
if self.state == State.IDLE:
|
||||||
self._active_hotkey = hotkey
|
|
||||||
self._start_recording_locked()
|
self._start_recording_locked()
|
||||||
return
|
return
|
||||||
if self.state == State.RECORDING and self._active_hotkey == hotkey:
|
if self.state == State.RECORDING:
|
||||||
should_stop = True
|
should_stop = True
|
||||||
else:
|
else:
|
||||||
logging.info("busy (%s), trigger ignored", self.state)
|
logging.info("busy (%s), trigger ignored", self.state)
|
||||||
|
|
@ -244,9 +113,6 @@ class Daemon:
|
||||||
if self.state != State.IDLE:
|
if self.state != State.IDLE:
|
||||||
logging.info("busy (%s), trigger ignored", self.state)
|
logging.info("busy (%s), trigger ignored", self.state)
|
||||||
return
|
return
|
||||||
if self._config_apply_in_progress:
|
|
||||||
logging.info("reload in progress, trigger ignored")
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
stream, record = start_audio_recording(self.cfg.recording.input)
|
stream, record = start_audio_recording(self.cfg.recording.input)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
@ -267,17 +133,10 @@ class Daemon:
|
||||||
def _timeout_stop(self):
|
def _timeout_stop(self):
|
||||||
self.stop_recording(trigger="timeout")
|
self.stop_recording(trigger="timeout")
|
||||||
|
|
||||||
def _start_stop_worker(
|
def _start_stop_worker(self, stream: Any, record: Any, trigger: str, process_audio: bool):
|
||||||
self,
|
|
||||||
stream: Any,
|
|
||||||
record: Any,
|
|
||||||
trigger: str,
|
|
||||||
process_audio: bool,
|
|
||||||
):
|
|
||||||
active_hotkey = self._pending_worker_hotkey or self.cfg.daemon.hotkey
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self._stop_and_process,
|
target=self._stop_and_process,
|
||||||
args=(stream, record, trigger, process_audio, active_hotkey),
|
args=(stream, record, trigger, process_audio),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
|
@ -286,8 +145,6 @@ class Daemon:
|
||||||
return None
|
return None
|
||||||
stream = self.stream
|
stream = self.stream
|
||||||
record = self.record
|
record = self.record
|
||||||
active_hotkey = self._active_hotkey or self.cfg.daemon.hotkey
|
|
||||||
self._active_hotkey = None
|
|
||||||
self.stream = None
|
self.stream = None
|
||||||
self.record = None
|
self.record = None
|
||||||
if self.timer:
|
if self.timer:
|
||||||
|
|
@ -301,17 +158,9 @@ class Daemon:
|
||||||
logging.warning("recording resources are unavailable during stop")
|
logging.warning("recording resources are unavailable during stop")
|
||||||
self.state = State.IDLE
|
self.state = State.IDLE
|
||||||
return None
|
return None
|
||||||
return stream, record, active_hotkey
|
return stream, record
|
||||||
|
|
||||||
def _stop_and_process(
|
def _stop_and_process(self, stream: Any, record: Any, trigger: str, process_audio: bool):
|
||||||
self,
|
|
||||||
stream: Any,
|
|
||||||
record: Any,
|
|
||||||
trigger: str,
|
|
||||||
process_audio: bool,
|
|
||||||
active_hotkey: str | None = None,
|
|
||||||
):
|
|
||||||
hotkey = active_hotkey or self.cfg.daemon.hotkey
|
|
||||||
logging.info("stopping recording (%s)", trigger)
|
logging.info("stopping recording (%s)", trigger)
|
||||||
try:
|
try:
|
||||||
audio = stop_audio_recording(stream, record)
|
audio = stop_audio_recording(stream, record)
|
||||||
|
|
@ -330,17 +179,44 @@ class Daemon:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pipeline started (%s)", hotkey)
|
self.set_state(State.STT)
|
||||||
text = self._run_pipeline(audio, hotkey).strip()
|
logging.info("stt started")
|
||||||
|
text = self._transcribe(audio)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error("pipeline failed: %s", exc)
|
logging.error("stt failed: %s", exc)
|
||||||
self.set_state(State.IDLE)
|
self.set_state(State.IDLE)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
text = (text or "").strip()
|
||||||
if not text:
|
if not text:
|
||||||
self.set_state(State.IDLE)
|
self.set_state(State.IDLE)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.log_transcript:
|
||||||
|
logging.debug("stt: %s", text)
|
||||||
|
else:
|
||||||
|
logging.info("stt produced %d chars", len(text))
|
||||||
|
|
||||||
|
domain = self.vocabulary.infer_domain(text)
|
||||||
|
if not self._shutdown_requested.is_set():
|
||||||
|
self.set_state(State.PROCESSING)
|
||||||
|
logging.info("ai processing started")
|
||||||
|
try:
|
||||||
|
processor = self._get_ai_processor()
|
||||||
|
ai_text = processor.process(
|
||||||
|
text,
|
||||||
|
lang=STT_LANGUAGE,
|
||||||
|
dictionary_context=self.vocabulary.build_ai_dictionary_context(),
|
||||||
|
domain_name=domain.name,
|
||||||
|
domain_confidence=domain.confidence,
|
||||||
|
)
|
||||||
|
if ai_text and ai_text.strip():
|
||||||
|
text = ai_text.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
logging.error("ai process failed: %s", exc)
|
||||||
|
|
||||||
|
text = self.vocabulary.apply_deterministic_replacements(text).strip()
|
||||||
|
|
||||||
if self.log_transcript:
|
if self.log_transcript:
|
||||||
logging.debug("processed: %s", text)
|
logging.debug("processed: %s", text)
|
||||||
else:
|
else:
|
||||||
|
|
@ -366,23 +242,13 @@ class Daemon:
|
||||||
finally:
|
finally:
|
||||||
self.set_state(State.IDLE)
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
def _run_pipeline(self, audio: Any, hotkey: str) -> str:
|
|
||||||
with self.lock:
|
|
||||||
binding = self._pipeline_bindings.get(hotkey)
|
|
||||||
if binding is None and self._pipeline_bindings:
|
|
||||||
binding = next(iter(self._pipeline_bindings.values()))
|
|
||||||
if binding is None:
|
|
||||||
return ""
|
|
||||||
return self.engine.run(binding, audio)
|
|
||||||
|
|
||||||
def stop_recording(self, *, trigger: str = "user", process_audio: bool = True):
|
def stop_recording(self, *, trigger: str = "user", process_audio: bool = True):
|
||||||
payload = None
|
payload = None
|
||||||
with self.lock:
|
with self.lock:
|
||||||
payload = self._begin_stop_locked()
|
payload = self._begin_stop_locked()
|
||||||
if payload is None:
|
if payload is None:
|
||||||
return
|
return
|
||||||
stream, record, active_hotkey = payload
|
stream, record = payload
|
||||||
self._pending_worker_hotkey = active_hotkey
|
|
||||||
self._start_stop_worker(stream, record, trigger, process_audio)
|
self._start_stop_worker(stream, record, trigger, process_audio)
|
||||||
|
|
||||||
def cancel_recording(self):
|
def cancel_recording(self):
|
||||||
|
|
@ -404,129 +270,40 @@ class Daemon:
|
||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
return self.get_state() == State.IDLE
|
return self.get_state() == State.IDLE
|
||||||
|
|
||||||
def build_reference_pipeline(self) -> Callable[[Any, PipelineLib], str]:
|
|
||||||
def _reference_pipeline(audio: Any, lib: PipelineLib) -> str:
|
|
||||||
text = (lib.transcribe(audio) or "").strip()
|
|
||||||
if not text:
|
|
||||||
return ""
|
|
||||||
dictionary_context = self.vocabulary.build_ai_dictionary_context()
|
|
||||||
context_prompt = {
|
|
||||||
"transcript": text,
|
|
||||||
"dictionary": dictionary_context,
|
|
||||||
}
|
|
||||||
context = (
|
|
||||||
lib.llm(
|
|
||||||
system_prompt=CONTEXT_SYSTEM_PROMPT,
|
|
||||||
user_prompt=json.dumps(context_prompt, ensure_ascii=False),
|
|
||||||
llm_opts={"temperature": 0.0},
|
|
||||||
)
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"language": STT_LANGUAGE,
|
|
||||||
"transcript": text,
|
|
||||||
"context": context,
|
|
||||||
}
|
|
||||||
if dictionary_context:
|
|
||||||
payload["dictionary"] = dictionary_context
|
|
||||||
ai_raw = lib.llm(
|
|
||||||
system_prompt=SYSTEM_PROMPT,
|
|
||||||
user_prompt=json.dumps(payload, ensure_ascii=False),
|
|
||||||
llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}},
|
|
||||||
)
|
|
||||||
cleaned_text = _extract_cleaned_text_from_llm_raw(ai_raw)
|
|
||||||
return self.vocabulary.apply_deterministic_replacements(cleaned_text).strip()
|
|
||||||
|
|
||||||
return _reference_pipeline
|
|
||||||
|
|
||||||
def _transcribe(self, audio) -> str:
|
def _transcribe(self, audio) -> str:
|
||||||
return self._transcribe_with_options(audio)
|
|
||||||
|
|
||||||
def _transcribe_with_options(
|
|
||||||
self,
|
|
||||||
audio: Any,
|
|
||||||
*,
|
|
||||||
hints: list[str] | None = None,
|
|
||||||
whisper_opts: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"language": STT_LANGUAGE,
|
"language": STT_LANGUAGE,
|
||||||
"vad_filter": True,
|
"vad_filter": True,
|
||||||
}
|
}
|
||||||
kwargs.update(self._stt_hint_kwargs(hints=hints))
|
kwargs.update(self._stt_hint_kwargs())
|
||||||
if whisper_opts:
|
|
||||||
kwargs.update(whisper_opts)
|
|
||||||
segments, _info = self.model.transcribe(audio, **kwargs)
|
segments, _info = self.model.transcribe(audio, **kwargs)
|
||||||
parts = []
|
parts = []
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
text = (seg.text or "").strip()
|
text = (seg.text or "").strip()
|
||||||
if text:
|
if text:
|
||||||
parts.append(text)
|
parts.append(text)
|
||||||
out = " ".join(parts).strip()
|
return " ".join(parts).strip()
|
||||||
if self.log_transcript:
|
|
||||||
logging.debug("stt: %s", out)
|
|
||||||
else:
|
|
||||||
logging.info("stt produced %d chars", len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _llm_with_options(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
system_prompt: str,
|
|
||||||
user_prompt: str,
|
|
||||||
llm_opts: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
if self.get_state() != State.PROCESSING:
|
|
||||||
self.set_state(State.PROCESSING)
|
|
||||||
logging.info("llm processing started")
|
|
||||||
processor = self._get_ai_processor()
|
|
||||||
return processor.chat(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_prompt=user_prompt,
|
|
||||||
llm_opts=llm_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_ai_processor(self) -> LlamaProcessor:
|
def _get_ai_processor(self) -> LlamaProcessor:
|
||||||
if self.ai_processor is None:
|
if self.ai_processor is None:
|
||||||
raise RuntimeError("ai processor is not initialized")
|
raise RuntimeError("ai processor is not initialized")
|
||||||
return self.ai_processor
|
return self.ai_processor
|
||||||
|
|
||||||
def _stt_hint_kwargs(self, *, hints: list[str] | None = None) -> dict[str, Any]:
|
def _stt_hint_kwargs(self) -> dict[str, Any]:
|
||||||
if not hints and self._stt_hint_kwargs_cache is not None:
|
if self._stt_hint_kwargs_cache is not None:
|
||||||
return self._stt_hint_kwargs_cache
|
return self._stt_hint_kwargs_cache
|
||||||
|
|
||||||
hotwords, initial_prompt = self.vocabulary.build_stt_hints()
|
hotwords, initial_prompt = self.vocabulary.build_stt_hints()
|
||||||
extra_hints = [item.strip() for item in (hints or []) if isinstance(item, str) and item.strip()]
|
|
||||||
if extra_hints:
|
|
||||||
words = [item.strip() for item in hotwords.split(",") if item.strip()]
|
|
||||||
words.extend(extra_hints)
|
|
||||||
deduped: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
for item in words:
|
|
||||||
key = item.casefold()
|
|
||||||
if key in seen:
|
|
||||||
continue
|
|
||||||
seen.add(key)
|
|
||||||
deduped.append(item)
|
|
||||||
merged = ", ".join(deduped)
|
|
||||||
hotwords = merged[:1024]
|
|
||||||
if hotwords:
|
|
||||||
initial_prompt = f"Preferred vocabulary: {hotwords}"[:600]
|
|
||||||
|
|
||||||
if not hotwords and not initial_prompt:
|
if not hotwords and not initial_prompt:
|
||||||
if not hints:
|
self._stt_hint_kwargs_cache = {}
|
||||||
self._stt_hint_kwargs_cache = {}
|
return self._stt_hint_kwargs_cache
|
||||||
return self._stt_hint_kwargs_cache
|
|
||||||
return {}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
signature = inspect.signature(self.model.transcribe)
|
signature = inspect.signature(self.model.transcribe)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
logging.debug("stt signature inspection failed; skipping hints")
|
logging.debug("stt signature inspection failed; skipping hints")
|
||||||
if not hints:
|
self._stt_hint_kwargs_cache = {}
|
||||||
self._stt_hint_kwargs_cache = {}
|
return self._stt_hint_kwargs_cache
|
||||||
return self._stt_hint_kwargs_cache
|
|
||||||
return {}
|
|
||||||
|
|
||||||
params = signature.parameters
|
params = signature.parameters
|
||||||
kwargs: dict[str, Any] = {}
|
kwargs: dict[str, Any] = {}
|
||||||
|
|
@ -536,288 +313,8 @@ class Daemon:
|
||||||
kwargs["initial_prompt"] = initial_prompt
|
kwargs["initial_prompt"] = initial_prompt
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
logging.debug("stt hint arguments are not supported by this whisper runtime")
|
logging.debug("stt hint arguments are not supported by this whisper runtime")
|
||||||
if not hints:
|
self._stt_hint_kwargs_cache = kwargs
|
||||||
self._stt_hint_kwargs_cache = kwargs
|
return self._stt_hint_kwargs_cache
|
||||||
return self._stt_hint_kwargs_cache
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
def try_apply_config(self, candidate: Config) -> tuple[str, list[str], str]:
|
|
||||||
with self.lock:
|
|
||||||
if self._shutdown_requested.is_set():
|
|
||||||
return "deferred", [], "shutdown in progress"
|
|
||||||
if self._config_apply_in_progress:
|
|
||||||
return "deferred", [], "config reload already in progress"
|
|
||||||
if self.state != State.IDLE:
|
|
||||||
return "deferred", [], f"daemon is busy ({self.state})"
|
|
||||||
self._config_apply_in_progress = True
|
|
||||||
current_cfg = self.cfg
|
|
||||||
|
|
||||||
try:
|
|
||||||
changed = _config_fields_changed(current_cfg, candidate)
|
|
||||||
if not changed:
|
|
||||||
return "applied", [], ""
|
|
||||||
|
|
||||||
next_model = None
|
|
||||||
if (
|
|
||||||
current_cfg.stt.model != candidate.stt.model
|
|
||||||
or current_cfg.stt.device != candidate.stt.device
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
next_model = _build_whisper_model(candidate.stt.model, candidate.stt.device)
|
|
||||||
except Exception as exc:
|
|
||||||
return "error", [], f"stt model reload failed: {exc}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
next_vocab = VocabularyEngine(candidate.vocabulary)
|
|
||||||
except Exception as exc:
|
|
||||||
return "error", [], f"vocabulary reload failed: {exc}"
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
if self._shutdown_requested.is_set():
|
|
||||||
return "deferred", [], "shutdown in progress"
|
|
||||||
if self.state != State.IDLE:
|
|
||||||
return "deferred", [], f"daemon is busy ({self.state})"
|
|
||||||
self.cfg = candidate
|
|
||||||
if next_model is not None:
|
|
||||||
self.model = next_model
|
|
||||||
self.vocabulary = next_vocab
|
|
||||||
self._stt_hint_kwargs_cache = None
|
|
||||||
return "applied", changed, ""
|
|
||||||
finally:
|
|
||||||
with self.lock:
|
|
||||||
self._config_apply_in_progress = False
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigReloader:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
daemon: Daemon,
|
|
||||||
config_path: Path,
|
|
||||||
notifier: DesktopNotifier | None = None,
|
|
||||||
poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC,
|
|
||||||
):
|
|
||||||
self.daemon = daemon
|
|
||||||
self.config_path = config_path
|
|
||||||
self.notifier = notifier
|
|
||||||
self.poll_interval_sec = poll_interval_sec
|
|
||||||
self._stop_event = threading.Event()
|
|
||||||
self._request_lock = threading.Lock()
|
|
||||||
self._pending_reload_reason = ""
|
|
||||||
self._pending_force = False
|
|
||||||
self._last_seen_fingerprint = self._fingerprint()
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
if self._thread is not None:
|
|
||||||
return
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True, name="config-reloader")
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def stop(self, timeout: float = 2.0) -> None:
|
|
||||||
self._stop_event.set()
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=timeout)
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
def request_reload(self, reason: str, *, force: bool = False) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
if self._pending_reload_reason:
|
|
||||||
self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}"
|
|
||||||
else:
|
|
||||||
self._pending_reload_reason = reason
|
|
||||||
self._pending_force = self._pending_force or force
|
|
||||||
logging.debug("config reload requested: %s (force=%s)", reason, force)
|
|
||||||
|
|
||||||
def tick(self) -> None:
|
|
||||||
self._detect_file_change()
|
|
||||||
self._apply_if_pending()
|
|
||||||
|
|
||||||
def _run(self) -> None:
|
|
||||||
while not self._stop_event.wait(self.poll_interval_sec):
|
|
||||||
self.tick()
|
|
||||||
|
|
||||||
def _detect_file_change(self) -> None:
|
|
||||||
fingerprint = self._fingerprint()
|
|
||||||
if fingerprint == self._last_seen_fingerprint:
|
|
||||||
return
|
|
||||||
self._last_seen_fingerprint = fingerprint
|
|
||||||
self.request_reload("file change detected", force=False)
|
|
||||||
|
|
||||||
def _apply_if_pending(self) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
reason = self._pending_reload_reason
|
|
||||||
force_reload = self._pending_force
|
|
||||||
if not reason:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.daemon.get_state() != State.IDLE:
|
|
||||||
logging.debug("config reload deferred; daemon state=%s", self.daemon.get_state())
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
candidate = load(str(self.config_path))
|
|
||||||
except Exception as exc:
|
|
||||||
self._clear_pending()
|
|
||||||
msg = f"config reload failed ({reason}): {exc}"
|
|
||||||
logging.error(msg)
|
|
||||||
self._notify("Config Reload Failed", str(exc), error=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
status, changed, error = self.daemon.try_apply_config(candidate)
|
|
||||||
if status == "deferred":
|
|
||||||
logging.debug("config reload deferred during apply: %s", error)
|
|
||||||
return
|
|
||||||
self._clear_pending()
|
|
||||||
if status == "error":
|
|
||||||
msg = f"config reload failed ({reason}): {error}"
|
|
||||||
logging.error(msg)
|
|
||||||
self._notify("Config Reload Failed", error, error=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
if changed:
|
|
||||||
msg = ", ".join(changed)
|
|
||||||
logging.info("config reloaded (%s): %s", reason, msg)
|
|
||||||
self._notify("Config Reloaded", f"Applied: {msg}", error=False)
|
|
||||||
return
|
|
||||||
if force_reload:
|
|
||||||
logging.info("config reloaded (%s): no effective changes", reason)
|
|
||||||
self._notify("Config Reloaded", "No effective changes", error=False)
|
|
||||||
|
|
||||||
def _clear_pending(self) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
self._pending_reload_reason = ""
|
|
||||||
self._pending_force = False
|
|
||||||
|
|
||||||
def _fingerprint(self) -> str | None:
|
|
||||||
try:
|
|
||||||
data = self.config_path.read_bytes()
|
|
||||||
return hashlib.sha256(data).hexdigest()
|
|
||||||
except FileNotFoundError:
|
|
||||||
return None
|
|
||||||
except OSError as exc:
|
|
||||||
logging.warning("config fingerprint failed (%s): %s", self.config_path, exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _notify(self, title: str, body: str, *, error: bool) -> None:
|
|
||||||
if self.notifier is None:
|
|
||||||
return
|
|
||||||
sent = self.notifier.send(title, body, error=error)
|
|
||||||
if not sent:
|
|
||||||
logging.debug("desktop notifications unavailable: %s", title)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineReloader:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
daemon: Daemon,
|
|
||||||
pipelines_path: Path,
|
|
||||||
notifier: DesktopNotifier | None = None,
|
|
||||||
dry_run: bool = False,
|
|
||||||
poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC,
|
|
||||||
):
|
|
||||||
self.daemon = daemon
|
|
||||||
self.pipelines_path = pipelines_path
|
|
||||||
self.notifier = notifier
|
|
||||||
self.dry_run = dry_run
|
|
||||||
self.poll_interval_sec = poll_interval_sec
|
|
||||||
self._stop_event = threading.Event()
|
|
||||||
self._request_lock = threading.Lock()
|
|
||||||
self._pending_reload_reason = ""
|
|
||||||
self._pending_force = False
|
|
||||||
self._last_seen_fingerprint = pipelines_fingerprint(self.pipelines_path)
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
if self._thread is not None:
|
|
||||||
return
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True, name="pipeline-reloader")
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def stop(self, timeout: float = 2.0) -> None:
|
|
||||||
self._stop_event.set()
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=timeout)
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
def request_reload(self, reason: str, *, force: bool = False) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
if self._pending_reload_reason:
|
|
||||||
self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}"
|
|
||||||
else:
|
|
||||||
self._pending_reload_reason = reason
|
|
||||||
self._pending_force = self._pending_force or force
|
|
||||||
logging.debug("pipeline reload requested: %s (force=%s)", reason, force)
|
|
||||||
|
|
||||||
def tick(self) -> None:
|
|
||||||
self._detect_file_change()
|
|
||||||
self._apply_if_pending()
|
|
||||||
|
|
||||||
def _run(self) -> None:
|
|
||||||
while not self._stop_event.wait(self.poll_interval_sec):
|
|
||||||
self.tick()
|
|
||||||
|
|
||||||
def _detect_file_change(self) -> None:
|
|
||||||
fingerprint = pipelines_fingerprint(self.pipelines_path)
|
|
||||||
if fingerprint == self._last_seen_fingerprint:
|
|
||||||
return
|
|
||||||
self._last_seen_fingerprint = fingerprint
|
|
||||||
self.request_reload("file change detected", force=False)
|
|
||||||
|
|
||||||
def _apply_if_pending(self) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
reason = self._pending_reload_reason
|
|
||||||
force_reload = self._pending_force
|
|
||||||
if not reason:
|
|
||||||
return
|
|
||||||
if self.daemon.get_state() != State.IDLE:
|
|
||||||
logging.debug("pipeline reload deferred; daemon state=%s", self.daemon.get_state())
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
bindings = load_bindings(
|
|
||||||
path=self.pipelines_path,
|
|
||||||
default_hotkey=self.daemon.cfg.daemon.hotkey,
|
|
||||||
default_handler_factory=self.daemon.build_reference_pipeline,
|
|
||||||
)
|
|
||||||
callbacks = self.daemon.build_hotkey_callbacks(
|
|
||||||
list(bindings.keys()),
|
|
||||||
dry_run=self.dry_run,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
self._clear_pending()
|
|
||||||
logging.error("pipeline reload failed (%s): %s", reason, exc)
|
|
||||||
self._notify("Pipelines Reload Failed", str(exc), error=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
status, error = self.daemon.apply_pipeline_bindings(bindings, callbacks)
|
|
||||||
if status == "deferred":
|
|
||||||
logging.debug("pipeline reload deferred during apply: %s", error)
|
|
||||||
return
|
|
||||||
self._clear_pending()
|
|
||||||
if status == "error":
|
|
||||||
logging.error("pipeline reload failed (%s): %s", reason, error)
|
|
||||||
self._notify("Pipelines Reload Failed", error, error=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
hotkeys = ", ".join(sorted(bindings.keys()))
|
|
||||||
logging.info("pipelines reloaded (%s): %s", reason, hotkeys)
|
|
||||||
self._notify("Pipelines Reloaded", f"Hotkeys: {hotkeys}", error=False)
|
|
||||||
if force_reload and not bindings:
|
|
||||||
logging.info("pipelines reloaded (%s): no hotkeys", reason)
|
|
||||||
|
|
||||||
def _clear_pending(self) -> None:
|
|
||||||
with self._request_lock:
|
|
||||||
self._pending_reload_reason = ""
|
|
||||||
self._pending_force = False
|
|
||||||
|
|
||||||
def _notify(self, title: str, body: str, *, error: bool) -> None:
|
|
||||||
if self.notifier is None:
|
|
||||||
return
|
|
||||||
sent = self.notifier.send(title, body, error=error)
|
|
||||||
if not sent:
|
|
||||||
logging.debug("desktop notifications unavailable: %s", title)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_lock_pid(lock_file) -> str:
|
def _read_lock_pid(lock_file) -> str:
|
||||||
|
|
@ -869,49 +366,19 @@ def main():
|
||||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
format="aman: %(asctime)s %(levelname)s %(message)s",
|
format="aman: %(asctime)s %(levelname)s %(message)s",
|
||||||
)
|
)
|
||||||
config_path = Path(args.config) if args.config else DEFAULT_CONFIG_PATH
|
cfg = load(args.config)
|
||||||
cfg = load(str(config_path))
|
|
||||||
_LOCK_HANDLE = _lock_single_instance()
|
_LOCK_HANDLE = _lock_single_instance()
|
||||||
|
|
||||||
logging.info("hotkey: %s", cfg.daemon.hotkey)
|
logging.info("hotkey: %s", cfg.daemon.hotkey)
|
||||||
logging.info(
|
logging.info(
|
||||||
"config (%s):\n%s",
|
"config (%s):\n%s",
|
||||||
str(config_path),
|
args.config or str(Path.home() / ".config" / "aman" / "config.json"),
|
||||||
json.dumps(redacted_dict(cfg), indent=2),
|
json.dumps(redacted_dict(cfg), indent=2),
|
||||||
)
|
)
|
||||||
|
|
||||||
config_reloader = None
|
|
||||||
pipeline_reloader = None
|
|
||||||
try:
|
try:
|
||||||
desktop = get_desktop_adapter()
|
desktop = get_desktop_adapter()
|
||||||
daemon = Daemon(cfg, desktop, verbose=args.verbose)
|
daemon = Daemon(cfg, desktop, verbose=args.verbose)
|
||||||
notifier = DesktopNotifier(app_name="aman")
|
|
||||||
config_reloader = ConfigReloader(
|
|
||||||
daemon=daemon,
|
|
||||||
config_path=config_path,
|
|
||||||
notifier=notifier,
|
|
||||||
poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC,
|
|
||||||
)
|
|
||||||
pipelines_path = DEFAULT_PIPELINES_PATH
|
|
||||||
initial_bindings = load_bindings(
|
|
||||||
path=pipelines_path,
|
|
||||||
default_hotkey=cfg.daemon.hotkey,
|
|
||||||
default_handler_factory=daemon.build_reference_pipeline,
|
|
||||||
)
|
|
||||||
initial_callbacks = daemon.build_hotkey_callbacks(
|
|
||||||
list(initial_bindings.keys()),
|
|
||||||
dry_run=args.dry_run,
|
|
||||||
)
|
|
||||||
status, error = daemon.apply_pipeline_bindings(initial_bindings, initial_callbacks)
|
|
||||||
if status != "applied":
|
|
||||||
raise RuntimeError(f"pipeline setup failed: {error}")
|
|
||||||
pipeline_reloader = PipelineReloader(
|
|
||||||
daemon=daemon,
|
|
||||||
pipelines_path=pipelines_path,
|
|
||||||
notifier=notifier,
|
|
||||||
dry_run=args.dry_run,
|
|
||||||
poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error("startup failed: %s", exc)
|
logging.error("startup failed: %s", exc)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
@ -923,10 +390,6 @@ def main():
|
||||||
return
|
return
|
||||||
shutdown_once.set()
|
shutdown_once.set()
|
||||||
logging.info("%s, shutting down", reason)
|
logging.info("%s, shutting down", reason)
|
||||||
if config_reloader is not None:
|
|
||||||
config_reloader.stop(timeout=2.0)
|
|
||||||
if pipeline_reloader is not None:
|
|
||||||
pipeline_reloader.stop(timeout=2.0)
|
|
||||||
if not daemon.shutdown(timeout=5.0):
|
if not daemon.shutdown(timeout=5.0):
|
||||||
logging.warning("timed out waiting for idle state during shutdown")
|
logging.warning("timed out waiting for idle state during shutdown")
|
||||||
desktop.request_quit()
|
desktop.request_quit()
|
||||||
|
|
@ -934,30 +397,15 @@ def main():
|
||||||
def handle_signal(_sig, _frame):
|
def handle_signal(_sig, _frame):
|
||||||
threading.Thread(target=shutdown, args=("signal received",), daemon=True).start()
|
threading.Thread(target=shutdown, args=("signal received",), daemon=True).start()
|
||||||
|
|
||||||
def handle_reload_signal(_sig, _frame):
|
|
||||||
if shutdown_once.is_set():
|
|
||||||
return
|
|
||||||
if config_reloader is not None:
|
|
||||||
threading.Thread(
|
|
||||||
target=lambda: config_reloader.request_reload("signal SIGHUP", force=True),
|
|
||||||
daemon=True,
|
|
||||||
).start()
|
|
||||||
if pipeline_reloader is not None:
|
|
||||||
threading.Thread(
|
|
||||||
target=lambda: pipeline_reloader.request_reload("signal SIGHUP", force=True),
|
|
||||||
daemon=True,
|
|
||||||
).start()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handle_signal)
|
signal.signal(signal.SIGINT, handle_signal)
|
||||||
signal.signal(signal.SIGTERM, handle_signal)
|
signal.signal(signal.SIGTERM, handle_signal)
|
||||||
signal.signal(signal.SIGHUP, handle_reload_signal)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
desktop.start_hotkey_listener(
|
||||||
|
cfg.daemon.hotkey,
|
||||||
|
lambda: logging.info("hotkey pressed (dry-run)") if args.dry_run else daemon.toggle(),
|
||||||
|
)
|
||||||
desktop.start_cancel_listener(lambda: daemon.cancel_recording())
|
desktop.start_cancel_listener(lambda: daemon.cancel_recording())
|
||||||
if config_reloader is not None:
|
|
||||||
config_reloader.start()
|
|
||||||
if pipeline_reloader is not None:
|
|
||||||
pipeline_reloader.start()
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error("hotkey setup failed: %s", exc)
|
logging.error("hotkey setup failed: %s", exc)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
@ -965,10 +413,6 @@ def main():
|
||||||
try:
|
try:
|
||||||
desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested"))
|
desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested"))
|
||||||
finally:
|
finally:
|
||||||
if config_reloader is not None:
|
|
||||||
config_reloader.stop(timeout=2.0)
|
|
||||||
if pipeline_reloader is not None:
|
|
||||||
pipeline_reloader.stop(timeout=2.0)
|
|
||||||
daemon.shutdown(timeout=1.0)
|
daemon.shutdown(timeout=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,11 @@ class VocabularyConfig:
|
||||||
terms: list[str] = field(default_factory=list)
|
terms: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DomainInferenceConfig:
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
daemon: DaemonConfig = field(default_factory=DaemonConfig)
|
daemon: DaemonConfig = field(default_factory=DaemonConfig)
|
||||||
|
|
@ -58,6 +63,7 @@ class Config:
|
||||||
stt: SttConfig = field(default_factory=SttConfig)
|
stt: SttConfig = field(default_factory=SttConfig)
|
||||||
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
injection: InjectionConfig = field(default_factory=InjectionConfig)
|
||||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||||
|
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
|
||||||
|
|
||||||
|
|
||||||
def load(path: str | None) -> Config:
|
def load(path: str | None) -> Config:
|
||||||
|
|
@ -118,8 +124,20 @@ def validate(cfg: Config) -> None:
|
||||||
cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
|
cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
|
||||||
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
|
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
|
||||||
|
|
||||||
|
if not isinstance(cfg.domain_inference.enabled, bool):
|
||||||
|
raise ValueError("domain_inference.enabled must be boolean")
|
||||||
|
|
||||||
|
|
||||||
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
|
if "logging" in data:
|
||||||
|
raise ValueError("logging section is no longer supported; use -v/--verbose")
|
||||||
|
if "log_transcript" in data:
|
||||||
|
raise ValueError("log_transcript is no longer supported; use -v/--verbose")
|
||||||
|
if "ai" in data:
|
||||||
|
raise ValueError("ai section is no longer supported")
|
||||||
|
if "ai_enabled" in data:
|
||||||
|
raise ValueError("ai_enabled is no longer supported")
|
||||||
|
|
||||||
has_sections = any(
|
has_sections = any(
|
||||||
key in data
|
key in data
|
||||||
for key in (
|
for key in (
|
||||||
|
|
@ -128,6 +146,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
"stt",
|
"stt",
|
||||||
"injection",
|
"injection",
|
||||||
"vocabulary",
|
"vocabulary",
|
||||||
|
"domain_inference",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if has_sections:
|
if has_sections:
|
||||||
|
|
@ -136,6 +155,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
stt = _ensure_dict(data.get("stt"), "stt")
|
stt = _ensure_dict(data.get("stt"), "stt")
|
||||||
injection = _ensure_dict(data.get("injection"), "injection")
|
injection = _ensure_dict(data.get("injection"), "injection")
|
||||||
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
|
||||||
|
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
|
||||||
|
|
||||||
if "hotkey" in daemon:
|
if "hotkey" in daemon:
|
||||||
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
|
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
|
||||||
|
|
@ -156,7 +176,28 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
|
||||||
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
|
||||||
if "terms" in vocabulary:
|
if "terms" in vocabulary:
|
||||||
cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
|
cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
|
||||||
|
if "max_rules" in vocabulary:
|
||||||
|
raise ValueError("vocabulary.max_rules is no longer supported")
|
||||||
|
if "max_terms" in vocabulary:
|
||||||
|
raise ValueError("vocabulary.max_terms is no longer supported")
|
||||||
|
if "enabled" in domain_inference:
|
||||||
|
cfg.domain_inference.enabled = _as_bool(
|
||||||
|
domain_inference["enabled"], "domain_inference.enabled"
|
||||||
|
)
|
||||||
|
if "mode" in domain_inference:
|
||||||
|
raise ValueError("domain_inference.mode is no longer supported")
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
if "hotkey" in data:
|
||||||
|
cfg.daemon.hotkey = _as_nonempty_str(data["hotkey"], "hotkey")
|
||||||
|
if "input" in data:
|
||||||
|
cfg.recording.input = _as_recording_input(data["input"])
|
||||||
|
if "whisper_model" in data:
|
||||||
|
cfg.stt.model = _as_nonempty_str(data["whisper_model"], "whisper_model")
|
||||||
|
if "whisper_device" in data:
|
||||||
|
cfg.stt.device = _as_nonempty_str(data["whisper_device"], "whisper_device")
|
||||||
|
if "injection_backend" in data:
|
||||||
|
cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend")
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json"
|
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json"
|
||||||
CONFIG_RELOAD_POLL_INTERVAL_SEC = 1.0
|
|
||||||
RECORD_TIMEOUT_SEC = 300
|
RECORD_TIMEOUT_SEC = 300
|
||||||
STT_LANGUAGE = "en"
|
STT_LANGUAGE = "en"
|
||||||
TRAY_UPDATE_MS = 250
|
TRAY_UPDATE_MS = 250
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Callable, Protocol
|
||||||
|
|
||||||
|
|
||||||
class DesktopAdapter(Protocol):
|
class DesktopAdapter(Protocol):
|
||||||
def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None:
|
def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def start_cancel_listener(self, callback: Callable[[], None]) -> None:
|
def start_cancel_listener(self, callback: Callable[[], None]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
class WaylandAdapter:
|
class WaylandAdapter:
|
||||||
def set_hotkeys(self, _bindings: dict[str, Callable[[], None]]) -> None:
|
def start_hotkey_listener(self, _hotkey: str, _callback: Callable[[], None]) -> None:
|
||||||
raise SystemExit("Wayland hotkeys are not supported yet.")
|
raise SystemExit("Wayland hotkeys are not supported yet.")
|
||||||
|
|
||||||
def start_cancel_listener(self, _callback: Callable[[], None]) -> None:
|
def start_cancel_listener(self, _callback: Callable[[], None]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Iterable
|
from typing import Callable, Iterable
|
||||||
|
|
||||||
import gi
|
import gi
|
||||||
from Xlib import X, XK, display
|
from Xlib import X, XK, display
|
||||||
|
|
@ -43,8 +43,6 @@ class X11Adapter:
|
||||||
self.indicator = None
|
self.indicator = None
|
||||||
self.status_icon = None
|
self.status_icon = None
|
||||||
self.menu = None
|
self.menu = None
|
||||||
self._hotkey_listener_lock = threading.Lock()
|
|
||||||
self._hotkey_listeners: dict[str, dict[str, Any]] = {}
|
|
||||||
if AppIndicator3 is not None:
|
if AppIndicator3 is not None:
|
||||||
self.indicator = AppIndicator3.Indicator.new(
|
self.indicator = AppIndicator3.Indicator.new(
|
||||||
"aman",
|
"aman",
|
||||||
|
|
@ -67,29 +65,15 @@ class X11Adapter:
|
||||||
if self.menu:
|
if self.menu:
|
||||||
self.menu.popup(None, None, None, None, 0, _time)
|
self.menu.popup(None, None, None, None, 0, _time)
|
||||||
|
|
||||||
def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None:
|
def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None:
|
||||||
if not isinstance(bindings, dict):
|
mods, keysym = self._parse_hotkey(hotkey)
|
||||||
raise ValueError("bindings must be a dictionary")
|
self._validate_hotkey_registration(mods, keysym)
|
||||||
next_listeners: dict[str, dict[str, Any]] = {}
|
thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True)
|
||||||
try:
|
thread.start()
|
||||||
for hotkey, callback in bindings.items():
|
|
||||||
if not callable(callback):
|
|
||||||
raise ValueError(f"callback for hotkey {hotkey} must be callable")
|
|
||||||
next_listeners[hotkey] = self._spawn_hotkey_listener(hotkey, callback)
|
|
||||||
except Exception:
|
|
||||||
for listener in next_listeners.values():
|
|
||||||
self._stop_hotkey_listener(listener)
|
|
||||||
raise
|
|
||||||
|
|
||||||
with self._hotkey_listener_lock:
|
|
||||||
previous = self._hotkey_listeners
|
|
||||||
self._hotkey_listeners = next_listeners
|
|
||||||
for listener in previous.values():
|
|
||||||
self._stop_hotkey_listener(listener)
|
|
||||||
|
|
||||||
def start_cancel_listener(self, callback: Callable[[], None]) -> None:
|
def start_cancel_listener(self, callback: Callable[[], None]) -> None:
|
||||||
mods, keysym = self._parse_hotkey("Escape")
|
mods, keysym = self._parse_hotkey("Escape")
|
||||||
thread = threading.Thread(target=self._listen, args=(mods, keysym, callback, threading.Event()), daemon=True)
|
thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
def inject_text(
|
def inject_text(
|
||||||
|
|
@ -143,14 +127,7 @@ class X11Adapter:
|
||||||
finally:
|
finally:
|
||||||
self.request_quit()
|
self.request_quit()
|
||||||
|
|
||||||
def _listen(
|
def _listen(self, mods: int, keysym: int, callback: Callable[[], None]) -> None:
|
||||||
self,
|
|
||||||
mods: int,
|
|
||||||
keysym: int,
|
|
||||||
callback: Callable[[], None],
|
|
||||||
stop_event: threading.Event,
|
|
||||||
listener_meta: dict[str, Any] | None = None,
|
|
||||||
) -> None:
|
|
||||||
disp = None
|
disp = None
|
||||||
root = None
|
root = None
|
||||||
keycode = None
|
keycode = None
|
||||||
|
|
@ -158,26 +135,14 @@ class X11Adapter:
|
||||||
disp = display.Display()
|
disp = display.Display()
|
||||||
root = disp.screen().root
|
root = disp.screen().root
|
||||||
keycode = self._grab_hotkey(disp, root, mods, keysym)
|
keycode = self._grab_hotkey(disp, root, mods, keysym)
|
||||||
if listener_meta is not None:
|
while True:
|
||||||
listener_meta["display"] = disp
|
|
||||||
listener_meta["root"] = root
|
|
||||||
listener_meta["keycode"] = keycode
|
|
||||||
listener_meta["ready"].set()
|
|
||||||
while not stop_event.is_set():
|
|
||||||
if disp.pending_events() == 0:
|
|
||||||
time.sleep(0.05)
|
|
||||||
continue
|
|
||||||
ev = disp.next_event()
|
ev = disp.next_event()
|
||||||
if ev.type == X.KeyPress and ev.detail == keycode:
|
if ev.type == X.KeyPress and ev.detail == keycode:
|
||||||
state = ev.state & ~(X.LockMask | X.Mod2Mask)
|
state = ev.state & ~(X.LockMask | X.Mod2Mask)
|
||||||
if state == mods:
|
if state == mods:
|
||||||
callback()
|
callback()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if listener_meta is not None:
|
logging.error("hotkey listener stopped: %s", exc)
|
||||||
listener_meta["error"] = exc
|
|
||||||
listener_meta["ready"].set()
|
|
||||||
if not stop_event.is_set():
|
|
||||||
logging.error("hotkey listener stopped: %s", exc)
|
|
||||||
finally:
|
finally:
|
||||||
if root is not None and keycode is not None and disp is not None:
|
if root is not None and keycode is not None and disp is not None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -185,13 +150,6 @@ class X11Adapter:
|
||||||
disp.sync()
|
disp.sync()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if disp is not None:
|
|
||||||
try:
|
|
||||||
disp.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if listener_meta is not None:
|
|
||||||
listener_meta["ready"].set()
|
|
||||||
|
|
||||||
def _parse_hotkey(self, hotkey: str):
|
def _parse_hotkey(self, hotkey: str):
|
||||||
mods = 0
|
mods = 0
|
||||||
|
|
@ -227,51 +185,6 @@ class X11Adapter:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _spawn_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> dict[str, Any]:
|
|
||||||
mods, keysym = self._parse_hotkey(hotkey)
|
|
||||||
self._validate_hotkey_registration(mods, keysym)
|
|
||||||
stop_event = threading.Event()
|
|
||||||
listener_meta: dict[str, Any] = {
|
|
||||||
"hotkey": hotkey,
|
|
||||||
"mods": mods,
|
|
||||||
"keysym": keysym,
|
|
||||||
"stop_event": stop_event,
|
|
||||||
"display": None,
|
|
||||||
"root": None,
|
|
||||||
"keycode": None,
|
|
||||||
"error": None,
|
|
||||||
"ready": threading.Event(),
|
|
||||||
}
|
|
||||||
thread = threading.Thread(
|
|
||||||
target=self._listen,
|
|
||||||
args=(mods, keysym, callback, stop_event, listener_meta),
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
listener_meta["thread"] = thread
|
|
||||||
thread.start()
|
|
||||||
if not listener_meta["ready"].wait(timeout=2.0):
|
|
||||||
stop_event.set()
|
|
||||||
raise RuntimeError("hotkey listener setup timed out")
|
|
||||||
if listener_meta["error"] is not None:
|
|
||||||
stop_event.set()
|
|
||||||
raise listener_meta["error"]
|
|
||||||
if listener_meta["keycode"] is None:
|
|
||||||
stop_event.set()
|
|
||||||
raise RuntimeError("hotkey listener failed to initialize")
|
|
||||||
return listener_meta
|
|
||||||
|
|
||||||
def _stop_hotkey_listener(self, listener: dict[str, Any]) -> None:
|
|
||||||
listener["stop_event"].set()
|
|
||||||
disp = listener.get("display")
|
|
||||||
if disp is not None:
|
|
||||||
try:
|
|
||||||
disp.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
thread = listener.get("thread")
|
|
||||||
if thread is not None:
|
|
||||||
thread.join(timeout=1.0)
|
|
||||||
|
|
||||||
def _grab_hotkey(self, disp, root, mods, keysym):
|
def _grab_hotkey(self, disp, root, mods, keysym):
|
||||||
keycode = disp.keysym_to_keycode(keysym)
|
keycode = disp.keysym_to_keycode(keysym)
|
||||||
if keycode == 0:
|
if keycode == 0:
|
||||||
|
|
|
||||||
|
|
@ -1,86 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
|
|
||||||
ALLOWED_FAILURE_POLICIES = {"best_effort", "strict"}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PipelineOptions:
|
|
||||||
failure_policy: str = "best_effort"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
policy = (self.failure_policy or "").strip().lower()
|
|
||||||
if policy not in ALLOWED_FAILURE_POLICIES:
|
|
||||||
allowed = ", ".join(sorted(ALLOWED_FAILURE_POLICIES))
|
|
||||||
raise ValueError(f"failure_policy must be one of: {allowed}")
|
|
||||||
object.__setattr__(self, "failure_policy", policy)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PipelineBinding:
|
|
||||||
hotkey: str
|
|
||||||
handler: Callable[[Any, "PipelineLib"], str]
|
|
||||||
options: PipelineOptions
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineLib:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
transcribe_fn: Callable[[Any], str],
|
|
||||||
llm_fn: Callable[..., str],
|
|
||||||
):
|
|
||||||
self._transcribe_fn = transcribe_fn
|
|
||||||
self._llm_fn = llm_fn
|
|
||||||
|
|
||||||
def transcribe(
|
|
||||||
self,
|
|
||||||
audio: Any,
|
|
||||||
*,
|
|
||||||
hints: list[str] | None = None,
|
|
||||||
whisper_opts: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
return self._transcribe_fn(
|
|
||||||
audio,
|
|
||||||
hints=hints,
|
|
||||||
whisper_opts=whisper_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
def llm(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
system_prompt: str,
|
|
||||||
user_prompt: str,
|
|
||||||
llm_opts: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
return self._llm_fn(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_prompt=user_prompt,
|
|
||||||
llm_opts=llm_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
|
||||||
def __init__(self, lib: PipelineLib):
|
|
||||||
self.lib = lib
|
|
||||||
|
|
||||||
def run(self, binding: PipelineBinding, audio: Any) -> str:
|
|
||||||
try:
|
|
||||||
output = binding.handler(audio, self.lib)
|
|
||||||
except Exception as exc:
|
|
||||||
if binding.options.failure_policy == "strict":
|
|
||||||
raise
|
|
||||||
logging.error("pipeline failed for hotkey %s: %s", binding.hotkey, exc)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if output is None:
|
|
||||||
return ""
|
|
||||||
if not isinstance(output, str):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"pipeline for hotkey {binding.hotkey} returned non-string output"
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
|
|
||||||
class DesktopNotifier:
|
|
||||||
def __init__(self, app_name: str = "aman"):
|
|
||||||
self._app_name = app_name
|
|
||||||
self._notify = None
|
|
||||||
self._ready = False
|
|
||||||
self._attempted = False
|
|
||||||
|
|
||||||
def send(self, title: str, body: str, *, error: bool = False) -> bool:
|
|
||||||
_ = error
|
|
||||||
notify = self._ensure_backend()
|
|
||||||
if notify is None:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
notification = notify.Notification.new(title, body, None)
|
|
||||||
notification.show()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _ensure_backend(self):
|
|
||||||
if self._attempted:
|
|
||||||
return self._notify if self._ready else None
|
|
||||||
|
|
||||||
self._attempted = True
|
|
||||||
try:
|
|
||||||
import gi
|
|
||||||
|
|
||||||
gi.require_version("Notify", "0.7")
|
|
||||||
from gi.repository import Notify # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
if Notify.init(self._app_name):
|
|
||||||
self._notify = Notify
|
|
||||||
self._ready = True
|
|
||||||
except Exception:
|
|
||||||
self._notify = None
|
|
||||||
self._ready = False
|
|
||||||
return self._notify if self._ready else None
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import importlib.util
|
|
||||||
from pathlib import Path
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
from engine import PipelineBinding, PipelineOptions
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PIPELINES_PATH = Path.home() / ".config" / "aman" / "pipelines.py"
|
|
||||||
|
|
||||||
|
|
||||||
def fingerprint(path: Path) -> str | None:
|
|
||||||
try:
|
|
||||||
data = path.read_bytes()
|
|
||||||
except FileNotFoundError:
|
|
||||||
return None
|
|
||||||
return hashlib.sha256(data).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def load_bindings(
|
|
||||||
*,
|
|
||||||
path: Path,
|
|
||||||
default_hotkey: str,
|
|
||||||
default_handler_factory: Callable[[], Callable[[Any, Any], str]],
|
|
||||||
) -> dict[str, PipelineBinding]:
|
|
||||||
if not path.exists():
|
|
||||||
handler = default_handler_factory()
|
|
||||||
options = PipelineOptions(failure_policy="best_effort")
|
|
||||||
return {
|
|
||||||
default_hotkey: PipelineBinding(
|
|
||||||
hotkey=default_hotkey,
|
|
||||||
handler=handler,
|
|
||||||
options=options,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
module = _load_module(path)
|
|
||||||
return _bindings_from_module(module)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module(path: Path) -> ModuleType:
|
|
||||||
module_name = f"aman_user_pipelines_{hash(path.resolve())}"
|
|
||||||
spec = importlib.util.spec_from_file_location(module_name, str(path))
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ValueError(f"unable to load pipelines module: {path}")
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def _bindings_from_module(module: ModuleType) -> dict[str, PipelineBinding]:
|
|
||||||
raw_pipelines = getattr(module, "HOTKEY_PIPELINES", None)
|
|
||||||
if not isinstance(raw_pipelines, dict):
|
|
||||||
raise ValueError("HOTKEY_PIPELINES must be a dictionary")
|
|
||||||
if not raw_pipelines:
|
|
||||||
raise ValueError("HOTKEY_PIPELINES cannot be empty")
|
|
||||||
|
|
||||||
raw_options = getattr(module, "PIPELINE_OPTIONS", {})
|
|
||||||
if raw_options is None:
|
|
||||||
raw_options = {}
|
|
||||||
if not isinstance(raw_options, dict):
|
|
||||||
raise ValueError("PIPELINE_OPTIONS must be a dictionary when provided")
|
|
||||||
|
|
||||||
bindings: dict[str, PipelineBinding] = {}
|
|
||||||
for key, handler in raw_pipelines.items():
|
|
||||||
hotkey = _validate_hotkey(key)
|
|
||||||
if not callable(handler):
|
|
||||||
raise ValueError(f"HOTKEY_PIPELINES[{hotkey}] must be callable")
|
|
||||||
option_data = raw_options.get(hotkey, {})
|
|
||||||
options = _parse_options(hotkey, option_data)
|
|
||||||
bindings[hotkey] = PipelineBinding(
|
|
||||||
hotkey=hotkey,
|
|
||||||
handler=handler,
|
|
||||||
options=options,
|
|
||||||
)
|
|
||||||
return bindings
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_hotkey(value: Any) -> str:
|
|
||||||
if not isinstance(value, str):
|
|
||||||
raise ValueError("pipeline hotkey keys must be strings")
|
|
||||||
hotkey = value.strip()
|
|
||||||
if not hotkey:
|
|
||||||
raise ValueError("pipeline hotkey keys cannot be empty")
|
|
||||||
return hotkey
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_options(hotkey: str, value: Any) -> PipelineOptions:
|
|
||||||
if value is None:
|
|
||||||
return PipelineOptions()
|
|
||||||
if not isinstance(value, dict):
|
|
||||||
raise ValueError(f"PIPELINE_OPTIONS[{hotkey}] must be an object")
|
|
||||||
failure_policy = value.get("failure_policy", "best_effort")
|
|
||||||
return PipelineOptions(failure_policy=failure_policy)
|
|
||||||
|
|
@ -4,7 +4,101 @@ import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from config import VocabularyConfig
|
from config import DomainInferenceConfig, VocabularyConfig
|
||||||
|
|
||||||
|
|
||||||
|
DOMAIN_GENERAL = "general"
|
||||||
|
DOMAIN_PERSONAL_NAMES = "personal_names"
|
||||||
|
DOMAIN_SOFTWARE_DEV = "software_dev"
|
||||||
|
DOMAIN_OPS_INFRA = "ops_infra"
|
||||||
|
DOMAIN_BUSINESS = "business"
|
||||||
|
DOMAIN_MEDICAL_LEGAL = "medical_legal"
|
||||||
|
|
||||||
|
DOMAIN_ORDER = (
|
||||||
|
DOMAIN_PERSONAL_NAMES,
|
||||||
|
DOMAIN_SOFTWARE_DEV,
|
||||||
|
DOMAIN_OPS_INFRA,
|
||||||
|
DOMAIN_BUSINESS,
|
||||||
|
DOMAIN_MEDICAL_LEGAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
DOMAIN_KEYWORDS = {
|
||||||
|
DOMAIN_SOFTWARE_DEV: {
|
||||||
|
"api",
|
||||||
|
"bug",
|
||||||
|
"code",
|
||||||
|
"commit",
|
||||||
|
"docker",
|
||||||
|
"function",
|
||||||
|
"git",
|
||||||
|
"github",
|
||||||
|
"javascript",
|
||||||
|
"python",
|
||||||
|
"refactor",
|
||||||
|
"repository",
|
||||||
|
"typescript",
|
||||||
|
"unit",
|
||||||
|
"test",
|
||||||
|
},
|
||||||
|
DOMAIN_OPS_INFRA: {
|
||||||
|
"cluster",
|
||||||
|
"container",
|
||||||
|
"deploy",
|
||||||
|
"deployment",
|
||||||
|
"incident",
|
||||||
|
"kubernetes",
|
||||||
|
"monitoring",
|
||||||
|
"nginx",
|
||||||
|
"pod",
|
||||||
|
"prod",
|
||||||
|
"service",
|
||||||
|
"systemd",
|
||||||
|
"terraform",
|
||||||
|
},
|
||||||
|
DOMAIN_BUSINESS: {
|
||||||
|
"budget",
|
||||||
|
"client",
|
||||||
|
"deadline",
|
||||||
|
"finance",
|
||||||
|
"invoice",
|
||||||
|
"meeting",
|
||||||
|
"milestone",
|
||||||
|
"project",
|
||||||
|
"quarter",
|
||||||
|
"roadmap",
|
||||||
|
"sales",
|
||||||
|
"stakeholder",
|
||||||
|
},
|
||||||
|
DOMAIN_MEDICAL_LEGAL: {
|
||||||
|
"agreement",
|
||||||
|
"case",
|
||||||
|
"claim",
|
||||||
|
"compliance",
|
||||||
|
"contract",
|
||||||
|
"diagnosis",
|
||||||
|
"liability",
|
||||||
|
"patient",
|
||||||
|
"prescription",
|
||||||
|
"regulation",
|
||||||
|
"symptom",
|
||||||
|
"treatment",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DOMAIN_PHRASES = {
|
||||||
|
DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"),
|
||||||
|
DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"),
|
||||||
|
DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"),
|
||||||
|
DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"),
|
||||||
|
}
|
||||||
|
|
||||||
|
GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DomainResult:
|
||||||
|
name: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -14,9 +108,10 @@ class _ReplacementView:
|
||||||
|
|
||||||
|
|
||||||
class VocabularyEngine:
|
class VocabularyEngine:
|
||||||
def __init__(self, vocab_cfg: VocabularyConfig):
|
def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig):
|
||||||
self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
|
self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
|
||||||
self._terms = list(vocab_cfg.terms)
|
self._terms = list(vocab_cfg.terms)
|
||||||
|
self._domain_enabled = bool(domain_cfg.enabled)
|
||||||
|
|
||||||
self._replacement_map = {
|
self._replacement_map = {
|
||||||
_normalize_key(rule.source): rule.target for rule in self._replacements
|
_normalize_key(rule.source): rule.target for rule in self._replacements
|
||||||
|
|
@ -66,6 +161,55 @@ class VocabularyEngine:
|
||||||
used += addition
|
used += addition
|
||||||
return "\n".join(out)
|
return "\n".join(out)
|
||||||
|
|
||||||
|
def infer_domain(self, text: str) -> DomainResult:
|
||||||
|
if not self._domain_enabled:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
normalized = text.casefold()
|
||||||
|
tokens = re.findall(r"[a-z0-9+#./_-]+", normalized)
|
||||||
|
if not tokens:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
scores = {domain: 0 for domain in DOMAIN_ORDER}
|
||||||
|
for token in tokens:
|
||||||
|
for domain, keywords in DOMAIN_KEYWORDS.items():
|
||||||
|
if token in keywords:
|
||||||
|
scores[domain] += 2
|
||||||
|
|
||||||
|
for domain, phrases in DOMAIN_PHRASES.items():
|
||||||
|
for phrase in phrases:
|
||||||
|
if phrase in normalized:
|
||||||
|
scores[domain] += 2
|
||||||
|
|
||||||
|
if any(token in GREETING_TOKENS for token in tokens):
|
||||||
|
scores[DOMAIN_PERSONAL_NAMES] += 1
|
||||||
|
|
||||||
|
# Boost domains from configured dictionary terms and replacement targets.
|
||||||
|
dictionary_tokens = self._dictionary_tokens()
|
||||||
|
for token in dictionary_tokens:
|
||||||
|
for domain, keywords in DOMAIN_KEYWORDS.items():
|
||||||
|
if token in keywords and token in tokens:
|
||||||
|
scores[domain] += 1
|
||||||
|
|
||||||
|
top_domain = DOMAIN_GENERAL
|
||||||
|
top_score = 0
|
||||||
|
total_score = 0
|
||||||
|
for domain in DOMAIN_ORDER:
|
||||||
|
score = scores[domain]
|
||||||
|
total_score += score
|
||||||
|
if score > top_score:
|
||||||
|
top_score = score
|
||||||
|
top_domain = domain
|
||||||
|
|
||||||
|
if top_score < 2 or total_score == 0:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
confidence = top_score / total_score
|
||||||
|
if confidence < 0.45:
|
||||||
|
return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
|
||||||
|
|
||||||
|
return DomainResult(name=top_domain, confidence=round(confidence, 2))
|
||||||
|
|
||||||
def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
|
def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
|
||||||
items = _dedupe_preserve_order(
|
items = _dedupe_preserve_order(
|
||||||
[rule.target for rule in self._replacements] + self._terms
|
[rule.target for rule in self._replacements] + self._terms
|
||||||
|
|
@ -92,6 +236,19 @@ class VocabularyEngine:
|
||||||
return ""
|
return ""
|
||||||
return prefix + hotwords
|
return prefix + hotwords
|
||||||
|
|
||||||
|
def _dictionary_tokens(self) -> set[str]:
|
||||||
|
values: list[str] = []
|
||||||
|
for rule in self._replacements:
|
||||||
|
values.append(rule.source)
|
||||||
|
values.append(rule.target)
|
||||||
|
values.extend(self._terms)
|
||||||
|
|
||||||
|
tokens: set[str] = set()
|
||||||
|
for value in values:
|
||||||
|
for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()):
|
||||||
|
tokens.add(token)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
|
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
|
||||||
unique_sources = _dedupe_preserve_order(list(sources))
|
unique_sources = _dedupe_preserve_order(list(sources))
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -12,25 +11,14 @@ if str(SRC) not in sys.path:
|
||||||
sys.path.insert(0, str(SRC))
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
import aman
|
import aman
|
||||||
from config import Config, VocabularyReplacement, redacted_dict
|
from config import Config, VocabularyReplacement
|
||||||
from engine import PipelineBinding, PipelineOptions
|
|
||||||
|
|
||||||
|
|
||||||
class FakeDesktop:
|
class FakeDesktop:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.inject_calls = []
|
self.inject_calls = []
|
||||||
self.hotkey_updates = []
|
|
||||||
self.hotkeys = {}
|
|
||||||
self.cancel_callback = None
|
|
||||||
self.quit_calls = 0
|
self.quit_calls = 0
|
||||||
|
|
||||||
def set_hotkeys(self, bindings):
|
|
||||||
self.hotkeys = dict(bindings)
|
|
||||||
self.hotkey_updates.append(tuple(sorted(bindings.keys())))
|
|
||||||
|
|
||||||
def start_cancel_listener(self, callback):
|
|
||||||
self.cancel_callback = callback
|
|
||||||
|
|
||||||
def inject_text(
|
def inject_text(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
|
|
@ -88,30 +76,12 @@ class FakeAIProcessor:
|
||||||
def process(self, text, lang="en", **_kwargs):
|
def process(self, text, lang="en", **_kwargs):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def chat(self, *, system_prompt, user_prompt, llm_opts=None):
|
|
||||||
_ = system_prompt
|
|
||||||
opts = llm_opts or {}
|
|
||||||
if "response_format" in opts:
|
|
||||||
payload = json.loads(user_prompt)
|
|
||||||
transcript = payload.get("transcript", "")
|
|
||||||
return json.dumps({"cleaned_text": transcript})
|
|
||||||
return "general"
|
|
||||||
|
|
||||||
|
|
||||||
class FakeAudio:
|
class FakeAudio:
|
||||||
def __init__(self, size: int):
|
def __init__(self, size: int):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
|
|
||||||
class FakeNotifier:
|
|
||||||
def __init__(self):
|
|
||||||
self.events = []
|
|
||||||
|
|
||||||
def send(self, title, body, *, error=False):
|
|
||||||
self.events.append((title, body, error))
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class DaemonTests(unittest.TestCase):
|
class DaemonTests(unittest.TestCase):
|
||||||
def _config(self) -> Config:
|
def _config(self) -> Config:
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
@ -133,23 +103,19 @@ class DaemonTests(unittest.TestCase):
|
||||||
|
|
||||||
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
|
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
|
||||||
@patch("aman.start_audio_recording", return_value=(object(), object()))
|
@patch("aman.start_audio_recording", return_value=(object(), object()))
|
||||||
def test_hotkey_start_stop_injects_text(self, _start_mock, _stop_mock):
|
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
|
||||||
desktop = FakeDesktop()
|
desktop = FakeDesktop()
|
||||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
||||||
daemon._start_stop_worker = (
|
daemon._start_stop_worker = (
|
||||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||||
stream,
|
stream, record, trigger, process_audio
|
||||||
record,
|
|
||||||
trigger,
|
|
||||||
process_audio,
|
|
||||||
daemon._pending_worker_hotkey,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
||||||
|
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
|
|
||||||
self.assertEqual(daemon.get_state(), aman.State.IDLE)
|
self.assertEqual(daemon.get_state(), aman.State.IDLE)
|
||||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)])
|
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)])
|
||||||
|
|
@ -165,7 +131,7 @@ class DaemonTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
||||||
|
|
||||||
self.assertTrue(daemon.shutdown(timeout=0.2))
|
self.assertTrue(daemon.shutdown(timeout=0.2))
|
||||||
|
|
@ -187,8 +153,8 @@ class DaemonTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
|
|
||||||
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
|
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
|
||||||
|
|
||||||
|
|
@ -257,8 +223,8 @@ class DaemonTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
daemon.toggle()
|
||||||
|
|
||||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)])
|
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)])
|
||||||
|
|
||||||
|
|
@ -273,161 +239,6 @@ class DaemonTests(unittest.TestCase):
|
||||||
any("DEBUG:root:state: idle -> recording" in line for line in logs.output)
|
any("DEBUG:root:state: idle -> recording" in line for line in logs.output)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hotkey_dispatches_to_matching_pipeline(self):
|
|
||||||
desktop = FakeDesktop()
|
|
||||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
|
||||||
daemon.set_pipeline_bindings(
|
|
||||||
{
|
|
||||||
"Super+m": PipelineBinding(
|
|
||||||
hotkey="Super+m",
|
|
||||||
handler=lambda _audio, _lib: "default",
|
|
||||||
options=PipelineOptions(failure_policy="best_effort"),
|
|
||||||
),
|
|
||||||
"Super+Shift+m": PipelineBinding(
|
|
||||||
hotkey="Super+Shift+m",
|
|
||||||
handler=lambda _audio, _lib: "caps",
|
|
||||||
options=PipelineOptions(failure_policy="best_effort"),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
out = daemon._run_pipeline(object(), "Super+Shift+m")
|
|
||||||
self.assertEqual(out, "caps")
|
|
||||||
|
|
||||||
def test_try_apply_config_applies_new_runtime_settings(self):
|
|
||||||
desktop = FakeDesktop()
|
|
||||||
cfg = self._config()
|
|
||||||
daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False)
|
|
||||||
daemon._stt_hint_kwargs_cache = {"hotwords": "old"}
|
|
||||||
|
|
||||||
candidate = Config()
|
|
||||||
candidate.injection.backend = "injection"
|
|
||||||
candidate.vocabulary.replacements = [
|
|
||||||
VocabularyReplacement(source="Martha", target="Marta"),
|
|
||||||
]
|
|
||||||
|
|
||||||
status, changed, error = daemon.try_apply_config(candidate)
|
|
||||||
|
|
||||||
self.assertEqual(status, "applied")
|
|
||||||
self.assertEqual(error, "")
|
|
||||||
self.assertIn("injection.backend", changed)
|
|
||||||
self.assertIn("vocabulary.replacements", changed)
|
|
||||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
|
||||||
self.assertEqual(
|
|
||||||
daemon.vocabulary.apply_deterministic_replacements("Martha"),
|
|
||||||
"Marta",
|
|
||||||
)
|
|
||||||
self.assertIsNone(daemon._stt_hint_kwargs_cache)
|
|
||||||
|
|
||||||
def test_try_apply_config_reloads_stt_model(self):
|
|
||||||
desktop = FakeDesktop()
|
|
||||||
cfg = self._config()
|
|
||||||
daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False)
|
|
||||||
|
|
||||||
candidate = Config()
|
|
||||||
candidate.stt.model = "small"
|
|
||||||
next_model = FakeModel(text="from-new-model")
|
|
||||||
with patch("aman._build_whisper_model", return_value=next_model):
|
|
||||||
status, changed, error = daemon.try_apply_config(candidate)
|
|
||||||
|
|
||||||
self.assertEqual(status, "applied")
|
|
||||||
self.assertEqual(error, "")
|
|
||||||
self.assertIn("stt.model", changed)
|
|
||||||
self.assertIs(daemon.model, next_model)
|
|
||||||
|
|
||||||
def test_try_apply_config_is_deferred_while_busy(self):
|
|
||||||
desktop = FakeDesktop()
|
|
||||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
|
||||||
daemon.set_state(aman.State.RECORDING)
|
|
||||||
|
|
||||||
status, changed, error = daemon.try_apply_config(Config())
|
|
||||||
|
|
||||||
self.assertEqual(status, "deferred")
|
|
||||||
self.assertEqual(changed, [])
|
|
||||||
self.assertIn("busy", error)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigReloaderTests(unittest.TestCase):
|
|
||||||
def _write_config(self, path: Path, cfg: Config):
|
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
path.write_text(json.dumps(redacted_dict(cfg)), encoding="utf-8")
|
|
||||||
|
|
||||||
def _daemon_for_path(self, path: Path) -> tuple[aman.Daemon, FakeDesktop]:
|
|
||||||
cfg = Config()
|
|
||||||
self._write_config(path, cfg)
|
|
||||||
desktop = FakeDesktop()
|
|
||||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
|
||||||
"aman.LlamaProcessor", return_value=FakeAIProcessor()
|
|
||||||
):
|
|
||||||
daemon = aman.Daemon(cfg, desktop, verbose=False)
|
|
||||||
return daemon, desktop
|
|
||||||
|
|
||||||
def test_reloader_applies_changed_config(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "config.json"
|
|
||||||
daemon, _desktop = self._daemon_for_path(path)
|
|
||||||
notifier = FakeNotifier()
|
|
||||||
reloader = aman.ConfigReloader(
|
|
||||||
daemon=daemon,
|
|
||||||
config_path=path,
|
|
||||||
notifier=notifier,
|
|
||||||
poll_interval_sec=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated = Config()
|
|
||||||
updated.injection.backend = "injection"
|
|
||||||
self._write_config(path, updated)
|
|
||||||
reloader.tick()
|
|
||||||
|
|
||||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
|
||||||
self.assertTrue(any(evt[0] == "Config Reloaded" and not evt[2] for evt in notifier.events))
|
|
||||||
|
|
||||||
def test_reloader_keeps_last_good_config_when_invalid(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "config.json"
|
|
||||||
daemon, _desktop = self._daemon_for_path(path)
|
|
||||||
notifier = FakeNotifier()
|
|
||||||
reloader = aman.ConfigReloader(
|
|
||||||
daemon=daemon,
|
|
||||||
config_path=path,
|
|
||||||
notifier=notifier,
|
|
||||||
poll_interval_sec=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
path.write_text('{"injection":{"backend":"invalid"}}', encoding="utf-8")
|
|
||||||
reloader.tick()
|
|
||||||
self.assertEqual(daemon.cfg.injection.backend, "clipboard")
|
|
||||||
fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
|
|
||||||
self.assertEqual(len(fail_events), 1)
|
|
||||||
|
|
||||||
reloader.tick()
|
|
||||||
fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
|
|
||||||
self.assertEqual(len(fail_events), 1)
|
|
||||||
|
|
||||||
def test_reloader_defers_apply_until_idle(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "config.json"
|
|
||||||
daemon, _desktop = self._daemon_for_path(path)
|
|
||||||
notifier = FakeNotifier()
|
|
||||||
reloader = aman.ConfigReloader(
|
|
||||||
daemon=daemon,
|
|
||||||
config_path=path,
|
|
||||||
notifier=notifier,
|
|
||||||
poll_interval_sec=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated = Config()
|
|
||||||
updated.injection.backend = "injection"
|
|
||||||
self._write_config(path, updated)
|
|
||||||
|
|
||||||
daemon.set_state(aman.State.RECORDING)
|
|
||||||
reloader.tick()
|
|
||||||
self.assertEqual(daemon.cfg.injection.backend, "clipboard")
|
|
||||||
|
|
||||||
daemon.set_state(aman.State.IDLE)
|
|
||||||
reloader.tick()
|
|
||||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
|
||||||
|
|
||||||
|
|
||||||
class LockTests(unittest.TestCase):
|
class LockTests(unittest.TestCase):
|
||||||
def test_lock_rejects_second_instance(self):
|
def test_lock_rejects_second_instance(self):
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
||||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
self.assertEqual(cfg.vocabulary.terms, [])
|
self.assertEqual(cfg.vocabulary.terms, [])
|
||||||
|
self.assertTrue(cfg.domain_inference.enabled)
|
||||||
|
|
||||||
self.assertTrue(missing.exists())
|
self.assertTrue(missing.exists())
|
||||||
written = json.loads(missing.read_text(encoding="utf-8"))
|
written = json.loads(missing.read_text(encoding="utf-8"))
|
||||||
|
|
@ -47,6 +48,7 @@ class ConfigTests(unittest.TestCase):
|
||||||
],
|
],
|
||||||
"terms": ["Systemd", "Kubernetes"],
|
"terms": ["Systemd", "Kubernetes"],
|
||||||
},
|
},
|
||||||
|
"domain_inference": {"enabled": True},
|
||||||
}
|
}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
|
|
@ -64,6 +66,7 @@ class ConfigTests(unittest.TestCase):
|
||||||
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
||||||
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
||||||
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
|
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
|
||||||
|
self.assertTrue(cfg.domain_inference.enabled)
|
||||||
|
|
||||||
def test_super_modifier_hotkey_is_valid(self):
|
def test_super_modifier_hotkey_is_valid(self):
|
||||||
payload = {"daemon": {"hotkey": "Super+m"}}
|
payload = {"daemon": {"hotkey": "Super+m"}}
|
||||||
|
|
@ -95,6 +98,28 @@ class ConfigTests(unittest.TestCase):
|
||||||
):
|
):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
|
def test_loads_legacy_keys(self):
|
||||||
|
payload = {
|
||||||
|
"hotkey": "Alt+m",
|
||||||
|
"input": "Mic",
|
||||||
|
"whisper_model": "tiny",
|
||||||
|
"whisper_device": "cpu",
|
||||||
|
"injection_backend": "clipboard",
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
cfg = load(str(path))
|
||||||
|
|
||||||
|
self.assertEqual(cfg.daemon.hotkey, "Alt+m")
|
||||||
|
self.assertEqual(cfg.recording.input, "Mic")
|
||||||
|
self.assertEqual(cfg.stt.model, "tiny")
|
||||||
|
self.assertEqual(cfg.stt.device, "cpu")
|
||||||
|
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||||
|
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
||||||
|
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||||
|
|
||||||
def test_invalid_injection_backend_raises(self):
|
def test_invalid_injection_backend_raises(self):
|
||||||
payload = {"injection": {"backend": "invalid"}}
|
payload = {"injection": {"backend": "invalid"}}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
|
@ -113,20 +138,41 @@ class ConfigTests(unittest.TestCase):
|
||||||
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
|
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
def test_unknown_top_level_fields_are_ignored(self):
|
def test_removed_ai_section_raises(self):
|
||||||
payload = {
|
payload = {"ai": {"enabled": True}}
|
||||||
"custom_a": {"enabled": True},
|
|
||||||
"custom_b": {"nested": "value"},
|
|
||||||
"custom_c": 123,
|
|
||||||
}
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
cfg = load(str(path))
|
with self.assertRaisesRegex(ValueError, "ai section is no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
self.assertEqual(cfg.daemon.hotkey, "Cmd+m")
|
def test_removed_legacy_ai_enabled_raises(self):
|
||||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
payload = {"ai_enabled": True}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "ai_enabled is no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
|
def test_removed_logging_section_raises(self):
|
||||||
|
payload = {"logging": {"log_transcript": True}}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
|
def test_removed_legacy_log_transcript_raises(self):
|
||||||
|
payload = {"log_transcript": True}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
def test_conflicting_replacements_raise(self):
|
def test_conflicting_replacements_raise(self):
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -178,15 +224,32 @@ class ConfigTests(unittest.TestCase):
|
||||||
with self.assertRaisesRegex(ValueError, "wildcard"):
|
with self.assertRaisesRegex(ValueError, "wildcard"):
|
||||||
load(str(path))
|
load(str(path))
|
||||||
|
|
||||||
def test_unknown_vocabulary_fields_are_ignored(self):
|
def test_removed_domain_mode_raises(self):
|
||||||
payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}}
|
payload = {"domain_inference": {"mode": "heuristic"}}
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
path = Path(td) / "config.json"
|
path = Path(td) / "config.json"
|
||||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
cfg = load(str(path))
|
with self.assertRaisesRegex(ValueError, "domain_inference.mode is no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
self.assertEqual(cfg.vocabulary.terms, ["Docker"])
|
def test_removed_vocabulary_max_rules_raises(self):
|
||||||
|
payload = {"vocabulary": {"max_rules": 100}}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "vocabulary.max_rules is no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
|
def test_removed_vocabulary_max_terms_raises(self):
|
||||||
|
payload = {"vocabulary": {"max_terms": 100}}
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
path = Path(td) / "config.json"
|
||||||
|
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "vocabulary.max_terms is no longer supported"):
|
||||||
|
load(str(path))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
SRC = ROOT / "src"
|
|
||||||
if str(SRC) not in sys.path:
|
|
||||||
sys.path.insert(0, str(SRC))
|
|
||||||
|
|
||||||
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
|
|
||||||
|
|
||||||
|
|
||||||
class EngineTests(unittest.TestCase):
|
|
||||||
def test_best_effort_pipeline_failure_returns_empty_string(self):
|
|
||||||
lib = PipelineLib(
|
|
||||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
|
||||||
llm_fn=lambda **_kwargs: "result",
|
|
||||||
)
|
|
||||||
engine = Engine(lib)
|
|
||||||
|
|
||||||
def failing_pipeline(_audio, _lib):
|
|
||||||
raise RuntimeError("boom")
|
|
||||||
|
|
||||||
binding = PipelineBinding(
|
|
||||||
hotkey="Cmd+m",
|
|
||||||
handler=failing_pipeline,
|
|
||||||
options=PipelineOptions(failure_policy="best_effort"),
|
|
||||||
)
|
|
||||||
out = engine.run(binding, object())
|
|
||||||
self.assertEqual(out, "")
|
|
||||||
|
|
||||||
def test_strict_pipeline_failure_raises(self):
|
|
||||||
lib = PipelineLib(
|
|
||||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
|
||||||
llm_fn=lambda **_kwargs: "result",
|
|
||||||
)
|
|
||||||
engine = Engine(lib)
|
|
||||||
|
|
||||||
def failing_pipeline(_audio, _lib):
|
|
||||||
raise RuntimeError("boom")
|
|
||||||
|
|
||||||
binding = PipelineBinding(
|
|
||||||
hotkey="Cmd+m",
|
|
||||||
handler=failing_pipeline,
|
|
||||||
options=PipelineOptions(failure_policy="strict"),
|
|
||||||
)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "boom"):
|
|
||||||
engine.run(binding, object())
|
|
||||||
|
|
||||||
def test_pipeline_lib_forwards_arguments(self):
|
|
||||||
seen = {}
|
|
||||||
|
|
||||||
def transcribe_fn(audio, *, hints=None, whisper_opts=None):
|
|
||||||
seen["audio"] = audio
|
|
||||||
seen["hints"] = hints
|
|
||||||
seen["whisper_opts"] = whisper_opts
|
|
||||||
return "hello"
|
|
||||||
|
|
||||||
def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
|
|
||||||
seen["system_prompt"] = system_prompt
|
|
||||||
seen["user_prompt"] = user_prompt
|
|
||||||
seen["llm_opts"] = llm_opts
|
|
||||||
return "world"
|
|
||||||
|
|
||||||
lib = PipelineLib(
|
|
||||||
transcribe_fn=transcribe_fn,
|
|
||||||
llm_fn=llm_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
audio = object()
|
|
||||||
self.assertEqual(
|
|
||||||
lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
|
|
||||||
"hello",
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
lib.llm(
|
|
||||||
system_prompt="sys",
|
|
||||||
user_prompt="user",
|
|
||||||
llm_opts={"temperature": 0.2},
|
|
||||||
),
|
|
||||||
"world",
|
|
||||||
)
|
|
||||||
self.assertIs(seen["audio"], audio)
|
|
||||||
self.assertEqual(seen["hints"], ["Docker"])
|
|
||||||
self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
|
|
||||||
self.assertEqual(seen["system_prompt"], "sys")
|
|
||||||
self.assertEqual(seen["user_prompt"], "user")
|
|
||||||
self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
import tempfile
|
|
||||||
import textwrap
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
SRC = ROOT / "src"
|
|
||||||
if str(SRC) not in sys.path:
|
|
||||||
sys.path.insert(0, str(SRC))
|
|
||||||
|
|
||||||
from pipelines_runtime import fingerprint, load_bindings
|
|
||||||
|
|
||||||
|
|
||||||
class PipelinesRuntimeTests(unittest.TestCase):
|
|
||||||
def test_missing_file_uses_default_binding(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "pipelines.py"
|
|
||||||
called = {"count": 0}
|
|
||||||
|
|
||||||
def factory():
|
|
||||||
called["count"] += 1
|
|
||||||
|
|
||||||
def handler(_audio, _lib):
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
bindings = load_bindings(
|
|
||||||
path=path,
|
|
||||||
default_hotkey="Cmd+m",
|
|
||||||
default_handler_factory=factory,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(list(bindings.keys()), ["Cmd+m"])
|
|
||||||
self.assertEqual(called["count"], 1)
|
|
||||||
|
|
||||||
def test_loads_module_bindings_and_options(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "pipelines.py"
|
|
||||||
path.write_text(
|
|
||||||
textwrap.dedent(
|
|
||||||
"""
|
|
||||||
def p1(audio, lib):
|
|
||||||
return "one"
|
|
||||||
|
|
||||||
def p2(audio, lib):
|
|
||||||
return "two"
|
|
||||||
|
|
||||||
HOTKEY_PIPELINES = {
|
|
||||||
"Cmd+m": p1,
|
|
||||||
"Ctrl+space": p2,
|
|
||||||
}
|
|
||||||
|
|
||||||
PIPELINE_OPTIONS = {
|
|
||||||
"Ctrl+space": {"failure_policy": "strict"},
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
bindings = load_bindings(
|
|
||||||
path=path,
|
|
||||||
default_hotkey="Cmd+m",
|
|
||||||
default_handler_factory=lambda: (lambda _audio, _lib: "default"),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(sorted(bindings.keys()), ["Cmd+m", "Ctrl+space"])
|
|
||||||
self.assertEqual(bindings["Cmd+m"].options.failure_policy, "best_effort")
|
|
||||||
self.assertEqual(bindings["Ctrl+space"].options.failure_policy, "strict")
|
|
||||||
|
|
||||||
def test_invalid_options_fail(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "pipelines.py"
|
|
||||||
path.write_text(
|
|
||||||
textwrap.dedent(
|
|
||||||
"""
|
|
||||||
def p(audio, lib):
|
|
||||||
return "ok"
|
|
||||||
HOTKEY_PIPELINES = {"Cmd+m": p}
|
|
||||||
PIPELINE_OPTIONS = {"Cmd+m": {"failure_policy": "invalid"}}
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "failure_policy"):
|
|
||||||
load_bindings(
|
|
||||||
path=path,
|
|
||||||
default_hotkey="Cmd+m",
|
|
||||||
default_handler_factory=lambda: (lambda _audio, _lib: "default"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_fingerprint_changes_with_content(self):
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
|
||||||
path = Path(td) / "pipelines.py"
|
|
||||||
self.assertIsNone(fingerprint(path))
|
|
||||||
path.write_text("HOTKEY_PIPELINES = {}", encoding="utf-8")
|
|
||||||
first = fingerprint(path)
|
|
||||||
path.write_text("HOTKEY_PIPELINES = {'Cmd+m': lambda a, l: ''}", encoding="utf-8")
|
|
||||||
second = fingerprint(path)
|
|
||||||
self.assertIsNotNone(first)
|
|
||||||
self.assertIsNotNone(second)
|
|
||||||
self.assertNotEqual(first, second)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|
@ -7,17 +7,18 @@ SRC = ROOT / "src"
|
||||||
if str(SRC) not in sys.path:
|
if str(SRC) not in sys.path:
|
||||||
sys.path.insert(0, str(SRC))
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
from config import VocabularyConfig, VocabularyReplacement
|
from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
|
||||||
from vocabulary import VocabularyEngine
|
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
|
||||||
|
|
||||||
|
|
||||||
class VocabularyEngineTests(unittest.TestCase):
|
class VocabularyEngineTests(unittest.TestCase):
|
||||||
def _engine(self, replacements=None, terms=None):
|
def _engine(self, replacements=None, terms=None, domain_enabled=True):
|
||||||
vocab = VocabularyConfig(
|
vocab = VocabularyConfig(
|
||||||
replacements=replacements or [],
|
replacements=replacements or [],
|
||||||
terms=terms or [],
|
terms=terms or [],
|
||||||
)
|
)
|
||||||
return VocabularyEngine(vocab)
|
domain = DomainInferenceConfig(enabled=domain_enabled)
|
||||||
|
return VocabularyEngine(vocab, domain)
|
||||||
|
|
||||||
def test_boundary_aware_replacement(self):
|
def test_boundary_aware_replacement(self):
|
||||||
engine = self._engine(
|
engine = self._engine(
|
||||||
|
|
@ -49,5 +50,27 @@ class VocabularyEngineTests(unittest.TestCase):
|
||||||
self.assertLessEqual(len(hotwords), 1024)
|
self.assertLessEqual(len(hotwords), 1024)
|
||||||
self.assertLessEqual(len(prompt), 600)
|
self.assertLessEqual(len(prompt), 600)
|
||||||
|
|
||||||
|
def test_domain_inference_general_fallback(self):
|
||||||
|
engine = self._engine()
|
||||||
|
result = engine.infer_domain("please call me later")
|
||||||
|
|
||||||
|
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||||
|
self.assertEqual(result.confidence, 0.0)
|
||||||
|
|
||||||
|
def test_domain_inference_for_technical_text(self):
|
||||||
|
engine = self._engine(terms=["Docker", "Systemd"])
|
||||||
|
result = engine.infer_domain("restart Docker and systemd service on prod")
|
||||||
|
|
||||||
|
self.assertNotEqual(result.name, DOMAIN_GENERAL)
|
||||||
|
self.assertGreater(result.confidence, 0.0)
|
||||||
|
|
||||||
|
def test_domain_inference_can_be_disabled(self):
|
||||||
|
engine = self._engine(domain_enabled=False)
|
||||||
|
result = engine.infer_domain("please restart docker")
|
||||||
|
|
||||||
|
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||||
|
self.assertEqual(result.confidence, 0.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue