Revert "Add pipeline engine and remove legacy compatibility paths"

This commit is contained in:
Thales Maciel 2026-02-26 12:54:47 -03:00
parent e221d49020
commit 5b38cc7dcd
18 changed files with 399 additions and 1523 deletions

View file

@ -80,7 +80,8 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi
{ "from": "docker", "to": "Docker" } { "from": "docker", "to": "Docker" }
], ],
"terms": ["Systemd", "Kubernetes"] "terms": ["Systemd", "Kubernetes"]
} },
"domain_inference": { "enabled": true }
} }
``` ```
@ -91,9 +92,6 @@ Hotkey notes:
- Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`). - Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`).
- `Super` and `Cmd` are equivalent aliases for the same modifier. - `Super` and `Cmd` are equivalent aliases for the same modifier.
- Invalid hotkey syntax in config prevents startup/reload.
- When `~/.config/aman/pipelines.py` exists, hotkeys come from `HOTKEY_PIPELINES`.
- `daemon.hotkey` is used as the fallback/default hotkey only when no pipelines file is present.
AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model
downloaded to `~/.cache/aman/models/` during daemon initialization. downloaded to `~/.cache/aman/models/` during daemon initialization.
@ -109,6 +107,11 @@ Vocabulary correction:
- Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules. - Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules.
- Rules are deduplicated case-insensitively; conflicting replacements are rejected. - Rules are deduplicated case-insensitively; conflicting replacements are rejected.
Domain inference:
- Domain context is advisory only and is used to improve cleanup prompts.
- When confidence is low, it falls back to `general` context.
STT hinting: STT hinting:
- Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those - Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those
@ -131,12 +134,6 @@ systemctl --user enable --now aman
- Press it again to stop and run STT. - Press it again to stop and run STT.
- Press `Esc` while recording to cancel without processing. - Press `Esc` while recording to cancel without processing.
- Transcript contents are logged only when `-v/--verbose` is used. - Transcript contents are logged only when `-v/--verbose` is used.
- Config changes are hot-reloaded automatically (polled every 1 second).
- `~/.config/aman/pipelines.py` changes are hot-reloaded automatically (polled every 1 second).
- Send `SIGHUP` to force an immediate reload of config and pipelines:
`systemctl --user kill -s HUP aman` (or send `HUP` to the process directly).
- Reloads are applied when the daemon is idle; invalid updates are rejected and the last valid config stays active.
- Reload success/failure is logged, and desktop notifications are shown when available.
Wayland note: Wayland note:
@ -152,77 +149,6 @@ AI processing:
- Local llama.cpp model only (no remote provider configuration). - Local llama.cpp model only (no remote provider configuration).
## Pipelines API
`aman` is split into:
- shell daemon: hotkeys, recording/cancel, and desktop injection
- pipeline engine: `lib.transcribe(...)` and `lib.llm(...)`
- pipeline implementation: Python callables mapped per hotkey
Pipeline file path:
- `~/.config/aman/pipelines.py`
- You can start from [`pipelines.example.py`](./pipelines.example.py).
- If `pipelines.py` is missing, `aman` uses a built-in reference pipeline bound to `daemon.hotkey`.
- If `pipelines.py` exists but is invalid, startup fails fast.
- Pipelines are hot-reloaded automatically when the module file changes.
- Send `SIGHUP` to force an immediate reload of both config and pipelines.
Expected module exports:
```python
HOTKEY_PIPELINES = {
"Super+m": my_pipeline,
"Super+Shift+m": caps_pipeline,
}
PIPELINE_OPTIONS = {
"Super+Shift+m": {"failure_policy": "strict"}, # optional
}
```
Pipeline callable signature:
```python
def my_pipeline(audio, lib) -> str:
text = lib.transcribe(audio)
context = lib.llm(
system_prompt="context system prompt",
user_prompt=f"Transcript: {text}",
)
out = lib.llm(
system_prompt="amanuensis prompt",
user_prompt=f"context={context}\ntext={text}",
)
return out
```
`lib` API:
- `lib.transcribe(audio, hints=None, whisper_opts=None) -> str`
- `lib.llm(system_prompt=..., user_prompt=..., llm_opts=None) -> str`
Failure policy options:
- `best_effort` (default): pipeline errors return empty output
- `strict`: pipeline errors abort the current run
Validation:
- `HOTKEY_PIPELINES` must be a non-empty dictionary.
- Every hotkey key must be a non-empty string.
- Every pipeline value must be callable.
- `PIPELINE_OPTIONS` must be a dictionary when provided.
Reference behavior:
- The built-in fallback pipeline (used when `pipelines.py` is missing) uses `lib.llm(...)` twice:
- first to infer context
- second to run the amanuensis rewrite
- The second pass requests JSON output and expects `{"cleaned_text": "..."}`.
- Deterministic dictionary replacements are then applied as part of that reference implementation.
Control: Control:
```bash ```bash

View file

@ -1,50 +0,0 @@
import json

# System prompt for the first LLM pass: infer a short domain/style hint.
CONTEXT_PROMPT = (
    "Return a concise plain-text context hint (max 12 words) "
    "for the transcript domain and style."
)

# System prompt for the second LLM pass: the actual rewrite.
AMANUENSIS_PROMPT = (
    "Rewrite the transcript into clean prose without changing intent. "
    "Return only the final text."
)


def default_pipeline(audio, lib) -> str:
    """Transcribe *audio*, infer a context hint, then rewrite into clean prose.

    Returns the rewritten text, or "" when transcription yields nothing.
    """
    transcript = lib.transcribe(audio)
    if not transcript:
        return ""
    # Pass 1: ask the model for a short context hint about the transcript.
    hint = lib.llm(
        system_prompt=CONTEXT_PROMPT,
        user_prompt=json.dumps({"transcript": transcript}, ensure_ascii=False),
        llm_opts={"temperature": 0.0},
    ).strip()
    # Pass 2: rewrite the transcript with the inferred context attached.
    rewrite_payload = json.dumps(
        {"transcript": transcript, "context": hint},
        ensure_ascii=False,
    )
    rewritten = lib.llm(
        system_prompt=AMANUENSIS_PROMPT,
        user_prompt=rewrite_payload,
        llm_opts={"temperature": 0.0},
    )
    return rewritten.strip()
def caps_pipeline(audio, lib) -> str:
    """Transcribe *audio* and return the transcript upper-cased."""
    transcript = lib.transcribe(audio)
    return transcript.upper()
# Hotkey-to-pipeline routing table consumed by the daemon.
HOTKEY_PIPELINES = {
    "Super+m": default_pipeline,
    "Super+Shift+m": caps_pipeline,
}

# Optional per-hotkey engine options; failure_policy is best_effort or strict.
PIPELINE_OPTIONS = {
    "Super+m": {"failure_policy": "best_effort"},
    "Super+Shift+m": {"failure_policy": "strict"},
}

View file

@ -76,41 +76,18 @@ class LlamaProcessor:
if cleaned_dictionary: if cleaned_dictionary:
request_payload["dictionary"] = cleaned_dictionary request_payload["dictionary"] = cleaned_dictionary
content = self.chat(
system_prompt=SYSTEM_PROMPT,
user_prompt=json.dumps(request_payload, ensure_ascii=False),
llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}},
)
return _extract_cleaned_text_from_raw(content)
def chat(
self,
*,
system_prompt: str,
user_prompt: str,
llm_opts: dict[str, Any] | None = None,
) -> str:
opts = dict(llm_opts or {})
temperature = float(opts.pop("temperature", 0.0))
response_format = opts.pop("response_format", None)
if opts:
unknown = ", ".join(sorted(opts.keys()))
raise ValueError(f"unsupported llm options: {unknown}")
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"messages": [ "messages": [
{"role": "system", "content": system_prompt}, {"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt}, {"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)},
], ],
"temperature": temperature, "temperature": 0.0,
} }
if response_format is not None and _supports_response_format( if _supports_response_format(self.client.create_chat_completion):
self.client.create_chat_completion kwargs["response_format"] = {"type": "json_object"}
):
kwargs["response_format"] = response_format
response = self.client.create_chat_completion(**kwargs) response = self.client.create_chat_completion(**kwargs)
return _extract_chat_text(response) return _extract_cleaned_text(response)
def ensure_model(): def ensure_model():
@ -171,10 +148,6 @@ def _extract_chat_text(payload: Any) -> str:
def _extract_cleaned_text(payload: Any) -> str: def _extract_cleaned_text(payload: Any) -> str:
raw = _extract_chat_text(payload) raw = _extract_chat_text(payload)
return _extract_cleaned_text_from_raw(raw)
def _extract_cleaned_text_from_raw(raw: str) -> str:
try: try:
parsed = json.loads(raw) parsed = json.loads(raw)
except json.JSONDecodeError as exc: except json.JSONDecodeError as exc:

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import argparse import argparse
import errno import errno
import hashlib
import inspect import inspect
import json import json
import logging import logging
@ -13,21 +12,12 @@ import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any
from aiprocess import LlamaProcessor, SYSTEM_PROMPT from aiprocess import LlamaProcessor
from config import Config, load, redacted_dict from config import Config, load, redacted_dict
from constants import ( from constants import RECORD_TIMEOUT_SEC, STT_LANGUAGE
CONFIG_RELOAD_POLL_INTERVAL_SEC,
DEFAULT_CONFIG_PATH,
RECORD_TIMEOUT_SEC,
STT_LANGUAGE,
)
from desktop import get_desktop_adapter from desktop import get_desktop_adapter
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
from notify import DesktopNotifier
from pipelines_runtime import DEFAULT_PIPELINES_PATH, fingerprint as pipelines_fingerprint
from pipelines_runtime import load_bindings
from recorder import start_recording as start_audio_recording from recorder import start_recording as start_audio_recording
from recorder import stop_recording as stop_audio_recording from recorder import stop_recording as stop_audio_recording
from vocabulary import VocabularyEngine from vocabulary import VocabularyEngine
@ -65,52 +55,6 @@ def _compute_type(device: str) -> str:
return "int8" return "int8"
def _replacement_view(cfg: Config) -> list[tuple[str, str]]:
return [(item.source, item.target) for item in cfg.vocabulary.replacements]
def _config_fields_changed(current: Config, candidate: Config) -> list[str]:
changed: list[str] = []
if current.daemon.hotkey != candidate.daemon.hotkey:
changed.append("daemon.hotkey")
if current.recording.input != candidate.recording.input:
changed.append("recording.input")
if current.stt.model != candidate.stt.model:
changed.append("stt.model")
if current.stt.device != candidate.stt.device:
changed.append("stt.device")
if current.injection.backend != candidate.injection.backend:
changed.append("injection.backend")
if (
current.injection.remove_transcription_from_clipboard
!= candidate.injection.remove_transcription_from_clipboard
):
changed.append("injection.remove_transcription_from_clipboard")
if _replacement_view(current) != _replacement_view(candidate):
changed.append("vocabulary.replacements")
if current.vocabulary.terms != candidate.vocabulary.terms:
changed.append("vocabulary.terms")
return changed
CONTEXT_SYSTEM_PROMPT = (
"You infer concise writing context from transcript text.\n"
"Return a short plain-text hint (max 12 words) that describes the likely domain and style.\n"
"Do not add punctuation, quotes, XML tags, or markdown."
)
def _extract_cleaned_text_from_llm_raw(raw: str) -> str:
parsed = json.loads(raw)
if isinstance(parsed, str):
return parsed
if isinstance(parsed, dict):
cleaned_text = parsed.get("cleaned_text")
if isinstance(cleaned_text, str):
return cleaned_text
raise RuntimeError("unexpected ai output format: missing cleaned_text")
class Daemon: class Daemon:
def __init__(self, cfg: Config, desktop, *, verbose: bool = False): def __init__(self, cfg: Config, desktop, *, verbose: bool = False):
self.cfg = cfg self.cfg = cfg
@ -122,34 +66,17 @@ class Daemon:
self.stream = None self.stream = None
self.record = None self.record = None
self.timer: threading.Timer | None = None self.timer: threading.Timer | None = None
self._config_apply_in_progress = False self.model = _build_whisper_model(
self._active_hotkey: str | None = None cfg.stt.model,
self._pending_worker_hotkey: str | None = None cfg.stt.device,
self._pipeline_bindings: dict[str, PipelineBinding] = {} )
self.model = _build_whisper_model(cfg.stt.model, cfg.stt.device)
logging.info("initializing ai processor") logging.info("initializing ai processor")
self.ai_processor = LlamaProcessor(verbose=self.verbose) self.ai_processor = LlamaProcessor(verbose=self.verbose)
logging.info("ai processor ready") logging.info("ai processor ready")
self.log_transcript = verbose self.log_transcript = verbose
self.vocabulary = VocabularyEngine(cfg.vocabulary) self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference)
self._stt_hint_kwargs_cache: dict[str, Any] | None = None self._stt_hint_kwargs_cache: dict[str, Any] | None = None
self.lib = PipelineLib(
transcribe_fn=self._transcribe_with_options,
llm_fn=self._llm_with_options,
)
self.engine = Engine(self.lib)
self.set_pipeline_bindings(
{
cfg.daemon.hotkey: PipelineBinding(
hotkey=cfg.daemon.hotkey,
handler=self.build_reference_pipeline(),
options=PipelineOptions(failure_policy="best_effort"),
)
}
)
def set_state(self, state: str): def set_state(self, state: str):
with self.lock: with self.lock:
prev = self.state prev = self.state
@ -166,74 +93,16 @@ class Daemon:
def request_shutdown(self): def request_shutdown(self):
self._shutdown_requested.set() self._shutdown_requested.set()
def current_hotkeys(self) -> list[str]: def toggle(self):
with self.lock:
return list(self._pipeline_bindings.keys())
def set_pipeline_bindings(self, bindings: dict[str, PipelineBinding]) -> None:
    """Replace the hotkey-to-pipeline routing table atomically.

    Raises ValueError when *bindings* is empty. Clears the active hotkey
    when it no longer maps to any registered pipeline.
    """
    if not bindings:
        raise ValueError("at least one pipeline binding is required")
    with self.lock:
        self._pipeline_bindings = dict(bindings)
        active = self._active_hotkey
        if active and active not in self._pipeline_bindings:
            self._active_hotkey = None
def build_hotkey_callbacks(
    self,
    hotkeys: list[str],
    *,
    dry_run: bool,
) -> dict[str, Callable[[], None]]:
    """Map each hotkey to a zero-argument callback for the desktop adapter.

    In dry-run mode the callback only logs the press; otherwise it routes
    to handle_hotkey. The hotkey value is bound per-callback so each
    closure captures its own key, not the loop variable.
    """

    def make(hk: str) -> Callable[[], None]:
        if dry_run:
            return lambda: logging.info("hotkey pressed: %s (dry-run)", hk)
        return lambda: self.handle_hotkey(hk)

    return {hotkey: make(hotkey) for hotkey in hotkeys}
def apply_pipeline_bindings(
    self,
    bindings: dict[str, PipelineBinding],
    callbacks: dict[str, Callable[[], None]],
) -> tuple[str, str]:
    """Swap in new pipeline bindings and hotkey callbacks when safe.

    Returns ("applied", ""), ("deferred", reason) when the daemon cannot
    accept the update right now, or ("error", message) when the desktop
    adapter rejects the new callbacks.
    """
    with self.lock:
        # Refuse to start while shutting down, mid-reload, or not idle.
        if self._shutdown_requested.is_set():
            return "deferred", "shutdown in progress"
        if self._config_apply_in_progress:
            return "deferred", "reload in progress"
        if self.state != State.IDLE:
            return "deferred", f"daemon is busy ({self.state})"
        self._config_apply_in_progress = True
    try:
        self.desktop.set_hotkeys(callbacks)
        with self.lock:
            # Re-check: state may have moved while the adapter was updating.
            if self._shutdown_requested.is_set():
                return "deferred", "shutdown in progress"
            if self.state != State.IDLE:
                return "deferred", f"daemon is busy ({self.state})"
            self._pipeline_bindings = dict(bindings)
            return "applied", ""
    except Exception as exc:
        return "error", str(exc)
    finally:
        with self.lock:
            self._config_apply_in_progress = False
def handle_hotkey(self, hotkey: str):
should_stop = False should_stop = False
with self.lock: with self.lock:
if self._shutdown_requested.is_set(): if self._shutdown_requested.is_set():
logging.info("shutdown in progress, trigger ignored") logging.info("shutdown in progress, trigger ignored")
return return
if self._config_apply_in_progress:
logging.info("reload in progress, trigger ignored")
return
if self.state == State.IDLE: if self.state == State.IDLE:
self._active_hotkey = hotkey
self._start_recording_locked() self._start_recording_locked()
return return
if self.state == State.RECORDING and self._active_hotkey == hotkey: if self.state == State.RECORDING:
should_stop = True should_stop = True
else: else:
logging.info("busy (%s), trigger ignored", self.state) logging.info("busy (%s), trigger ignored", self.state)
@ -244,9 +113,6 @@ class Daemon:
if self.state != State.IDLE: if self.state != State.IDLE:
logging.info("busy (%s), trigger ignored", self.state) logging.info("busy (%s), trigger ignored", self.state)
return return
if self._config_apply_in_progress:
logging.info("reload in progress, trigger ignored")
return
try: try:
stream, record = start_audio_recording(self.cfg.recording.input) stream, record = start_audio_recording(self.cfg.recording.input)
except Exception as exc: except Exception as exc:
@ -267,17 +133,10 @@ class Daemon:
def _timeout_stop(self): def _timeout_stop(self):
self.stop_recording(trigger="timeout") self.stop_recording(trigger="timeout")
def _start_stop_worker( def _start_stop_worker(self, stream: Any, record: Any, trigger: str, process_audio: bool):
self,
stream: Any,
record: Any,
trigger: str,
process_audio: bool,
):
active_hotkey = self._pending_worker_hotkey or self.cfg.daemon.hotkey
threading.Thread( threading.Thread(
target=self._stop_and_process, target=self._stop_and_process,
args=(stream, record, trigger, process_audio, active_hotkey), args=(stream, record, trigger, process_audio),
daemon=True, daemon=True,
).start() ).start()
@ -286,8 +145,6 @@ class Daemon:
return None return None
stream = self.stream stream = self.stream
record = self.record record = self.record
active_hotkey = self._active_hotkey or self.cfg.daemon.hotkey
self._active_hotkey = None
self.stream = None self.stream = None
self.record = None self.record = None
if self.timer: if self.timer:
@ -301,17 +158,9 @@ class Daemon:
logging.warning("recording resources are unavailable during stop") logging.warning("recording resources are unavailable during stop")
self.state = State.IDLE self.state = State.IDLE
return None return None
return stream, record, active_hotkey return stream, record
def _stop_and_process( def _stop_and_process(self, stream: Any, record: Any, trigger: str, process_audio: bool):
self,
stream: Any,
record: Any,
trigger: str,
process_audio: bool,
active_hotkey: str | None = None,
):
hotkey = active_hotkey or self.cfg.daemon.hotkey
logging.info("stopping recording (%s)", trigger) logging.info("stopping recording (%s)", trigger)
try: try:
audio = stop_audio_recording(stream, record) audio = stop_audio_recording(stream, record)
@ -330,17 +179,44 @@ class Daemon:
return return
try: try:
logging.info("pipeline started (%s)", hotkey) self.set_state(State.STT)
text = self._run_pipeline(audio, hotkey).strip() logging.info("stt started")
text = self._transcribe(audio)
except Exception as exc: except Exception as exc:
logging.error("pipeline failed: %s", exc) logging.error("stt failed: %s", exc)
self.set_state(State.IDLE) self.set_state(State.IDLE)
return return
text = (text or "").strip()
if not text: if not text:
self.set_state(State.IDLE) self.set_state(State.IDLE)
return return
if self.log_transcript:
logging.debug("stt: %s", text)
else:
logging.info("stt produced %d chars", len(text))
domain = self.vocabulary.infer_domain(text)
if not self._shutdown_requested.is_set():
self.set_state(State.PROCESSING)
logging.info("ai processing started")
try:
processor = self._get_ai_processor()
ai_text = processor.process(
text,
lang=STT_LANGUAGE,
dictionary_context=self.vocabulary.build_ai_dictionary_context(),
domain_name=domain.name,
domain_confidence=domain.confidence,
)
if ai_text and ai_text.strip():
text = ai_text.strip()
except Exception as exc:
logging.error("ai process failed: %s", exc)
text = self.vocabulary.apply_deterministic_replacements(text).strip()
if self.log_transcript: if self.log_transcript:
logging.debug("processed: %s", text) logging.debug("processed: %s", text)
else: else:
@ -366,23 +242,13 @@ class Daemon:
finally: finally:
self.set_state(State.IDLE) self.set_state(State.IDLE)
def _run_pipeline(self, audio: Any, hotkey: str) -> str:
    """Run the pipeline bound to *hotkey*, falling back to the first binding.

    Returns "" when no pipelines are registered at all. The binding lookup
    happens under the lock; the pipeline itself runs outside it.
    """
    with self.lock:
        binding = self._pipeline_bindings.get(hotkey)
        if binding is None and self._pipeline_bindings:
            # Unknown hotkey: degrade gracefully to any registered pipeline.
            binding = next(iter(self._pipeline_bindings.values()))
    if binding is None:
        return ""
    return self.engine.run(binding, audio)
def stop_recording(self, *, trigger: str = "user", process_audio: bool = True): def stop_recording(self, *, trigger: str = "user", process_audio: bool = True):
payload = None payload = None
with self.lock: with self.lock:
payload = self._begin_stop_locked() payload = self._begin_stop_locked()
if payload is None: if payload is None:
return return
stream, record, active_hotkey = payload stream, record = payload
self._pending_worker_hotkey = active_hotkey
self._start_stop_worker(stream, record, trigger, process_audio) self._start_stop_worker(stream, record, trigger, process_audio)
def cancel_recording(self): def cancel_recording(self):
@ -404,129 +270,40 @@ class Daemon:
time.sleep(0.05) time.sleep(0.05)
return self.get_state() == State.IDLE return self.get_state() == State.IDLE
def build_reference_pipeline(self) -> Callable[[Any, PipelineLib], str]:
    """Build the built-in fallback pipeline used when no pipelines.py exists.

    The returned callable transcribes the audio, asks the LLM for a short
    context hint, runs the amanuensis rewrite in JSON mode, and finally
    applies deterministic vocabulary replacements to the cleaned text.
    """

    def _reference_pipeline(audio: Any, lib: PipelineLib) -> str:
        text = (lib.transcribe(audio) or "").strip()
        if not text:
            return ""
        dictionary_context = self.vocabulary.build_ai_dictionary_context()
        # Pass 1: infer a short plain-text context hint for the transcript.
        context = lib.llm(
            system_prompt=CONTEXT_SYSTEM_PROMPT,
            user_prompt=json.dumps(
                {"transcript": text, "dictionary": dictionary_context},
                ensure_ascii=False,
            ),
            llm_opts={"temperature": 0.0},
        ).strip()
        # Pass 2: JSON-mode rewrite using the inferred context.
        payload: dict[str, Any] = {
            "language": STT_LANGUAGE,
            "transcript": text,
            "context": context,
        }
        if dictionary_context:
            payload["dictionary"] = dictionary_context
        ai_raw = lib.llm(
            system_prompt=SYSTEM_PROMPT,
            user_prompt=json.dumps(payload, ensure_ascii=False),
            llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}},
        )
        cleaned_text = _extract_cleaned_text_from_llm_raw(ai_raw)
        return self.vocabulary.apply_deterministic_replacements(cleaned_text).strip()

    return _reference_pipeline
def _transcribe(self, audio) -> str: def _transcribe(self, audio) -> str:
return self._transcribe_with_options(audio)
def _transcribe_with_options(
self,
audio: Any,
*,
hints: list[str] | None = None,
whisper_opts: dict[str, Any] | None = None,
) -> str:
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"language": STT_LANGUAGE, "language": STT_LANGUAGE,
"vad_filter": True, "vad_filter": True,
} }
kwargs.update(self._stt_hint_kwargs(hints=hints)) kwargs.update(self._stt_hint_kwargs())
if whisper_opts:
kwargs.update(whisper_opts)
segments, _info = self.model.transcribe(audio, **kwargs) segments, _info = self.model.transcribe(audio, **kwargs)
parts = [] parts = []
for seg in segments: for seg in segments:
text = (seg.text or "").strip() text = (seg.text or "").strip()
if text: if text:
parts.append(text) parts.append(text)
out = " ".join(parts).strip() return " ".join(parts).strip()
if self.log_transcript:
logging.debug("stt: %s", out)
else:
logging.info("stt produced %d chars", len(out))
return out
def _llm_with_options(
    self,
    *,
    system_prompt: str,
    user_prompt: str,
    llm_opts: dict[str, Any] | None = None,
) -> str:
    """PipelineLib hook: run one chat completion on the local LLM.

    Flips the daemon into PROCESSING on the first LLM call of a run so
    state reporting stays accurate.
    """
    if self.get_state() != State.PROCESSING:
        self.set_state(State.PROCESSING)
        logging.info("llm processing started")
    return self._get_ai_processor().chat(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        llm_opts=llm_opts,
    )
def _get_ai_processor(self) -> LlamaProcessor: def _get_ai_processor(self) -> LlamaProcessor:
if self.ai_processor is None: if self.ai_processor is None:
raise RuntimeError("ai processor is not initialized") raise RuntimeError("ai processor is not initialized")
return self.ai_processor return self.ai_processor
def _stt_hint_kwargs(self, *, hints: list[str] | None = None) -> dict[str, Any]: def _stt_hint_kwargs(self) -> dict[str, Any]:
if not hints and self._stt_hint_kwargs_cache is not None: if self._stt_hint_kwargs_cache is not None:
return self._stt_hint_kwargs_cache return self._stt_hint_kwargs_cache
hotwords, initial_prompt = self.vocabulary.build_stt_hints() hotwords, initial_prompt = self.vocabulary.build_stt_hints()
extra_hints = [item.strip() for item in (hints or []) if isinstance(item, str) and item.strip()]
if extra_hints:
words = [item.strip() for item in hotwords.split(",") if item.strip()]
words.extend(extra_hints)
deduped: list[str] = []
seen: set[str] = set()
for item in words:
key = item.casefold()
if key in seen:
continue
seen.add(key)
deduped.append(item)
merged = ", ".join(deduped)
hotwords = merged[:1024]
if hotwords:
initial_prompt = f"Preferred vocabulary: {hotwords}"[:600]
if not hotwords and not initial_prompt: if not hotwords and not initial_prompt:
if not hints: self._stt_hint_kwargs_cache = {}
self._stt_hint_kwargs_cache = {} return self._stt_hint_kwargs_cache
return self._stt_hint_kwargs_cache
return {}
try: try:
signature = inspect.signature(self.model.transcribe) signature = inspect.signature(self.model.transcribe)
except (TypeError, ValueError): except (TypeError, ValueError):
logging.debug("stt signature inspection failed; skipping hints") logging.debug("stt signature inspection failed; skipping hints")
if not hints: self._stt_hint_kwargs_cache = {}
self._stt_hint_kwargs_cache = {} return self._stt_hint_kwargs_cache
return self._stt_hint_kwargs_cache
return {}
params = signature.parameters params = signature.parameters
kwargs: dict[str, Any] = {} kwargs: dict[str, Any] = {}
@ -536,288 +313,8 @@ class Daemon:
kwargs["initial_prompt"] = initial_prompt kwargs["initial_prompt"] = initial_prompt
if not kwargs: if not kwargs:
logging.debug("stt hint arguments are not supported by this whisper runtime") logging.debug("stt hint arguments are not supported by this whisper runtime")
if not hints: self._stt_hint_kwargs_cache = kwargs
self._stt_hint_kwargs_cache = kwargs return self._stt_hint_kwargs_cache
return self._stt_hint_kwargs_cache
return kwargs
def try_apply_config(self, candidate: Config) -> tuple[str, list[str], str]:
    """Attempt to hot-swap *candidate* as the active configuration.

    Returns (status, changed_fields, error): status is "applied",
    "deferred" (busy or shutting down), or "error" (a sub-component
    failed to rebuild). Expensive rebuilds run outside the lock and the
    final swap is re-validated under it.
    """
    with self.lock:
        if self._shutdown_requested.is_set():
            return "deferred", [], "shutdown in progress"
        if self._config_apply_in_progress:
            return "deferred", [], "config reload already in progress"
        if self.state != State.IDLE:
            return "deferred", [], f"daemon is busy ({self.state})"
        self._config_apply_in_progress = True
        current_cfg = self.cfg
    try:
        changed = _config_fields_changed(current_cfg, candidate)
        if not changed:
            return "applied", [], ""
        # Rebuild heavyweight components before taking the lock again.
        next_model = None
        stt_changed = (
            current_cfg.stt.model != candidate.stt.model
            or current_cfg.stt.device != candidate.stt.device
        )
        if stt_changed:
            try:
                next_model = _build_whisper_model(candidate.stt.model, candidate.stt.device)
            except Exception as exc:
                return "error", [], f"stt model reload failed: {exc}"
        try:
            next_vocab = VocabularyEngine(candidate.vocabulary)
        except Exception as exc:
            return "error", [], f"vocabulary reload failed: {exc}"
        with self.lock:
            # Re-check: the daemon may have become busy during the rebuild.
            if self._shutdown_requested.is_set():
                return "deferred", [], "shutdown in progress"
            if self.state != State.IDLE:
                return "deferred", [], f"daemon is busy ({self.state})"
            self.cfg = candidate
            if next_model is not None:
                self.model = next_model
            self.vocabulary = next_vocab
            # Hints depend on vocabulary; drop the cache so they rebuild.
            self._stt_hint_kwargs_cache = None
            return "applied", changed, ""
    finally:
        with self.lock:
            self._config_apply_in_progress = False
class ConfigReloader:
    """Background poller that hot-reloads the daemon config file.

    Watches the config file's content hash on a fixed interval and applies
    pending reloads only when the daemon is idle; rejected updates leave
    the last valid config active.
    """

    def __init__(
        self,
        *,
        daemon: Daemon,
        config_path: Path,
        notifier: DesktopNotifier | None = None,
        poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC,
    ):
        self.daemon = daemon
        self.config_path = config_path
        self.notifier = notifier
        self.poll_interval_sec = poll_interval_sec
        self._stop_event = threading.Event()
        self._request_lock = threading.Lock()
        self._pending_reload_reason = ""
        self._pending_force = False
        # Seed with the current hash so startup does not trigger a reload.
        self._last_seen_fingerprint = self._fingerprint()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the poller thread (idempotent)."""
        if self._thread is not None:
            return
        worker = threading.Thread(target=self._run, daemon=True, name="config-reloader")
        self._thread = worker
        worker.start()

    def stop(self, timeout: float = 2.0) -> None:
        """Signal the poller to stop and join it."""
        self._stop_event.set()
        worker = self._thread
        if worker is not None:
            worker.join(timeout=timeout)
            self._thread = None

    def request_reload(self, reason: str, *, force: bool = False) -> None:
        """Queue a reload; reasons accumulate until the reload is applied."""
        with self._request_lock:
            if self._pending_reload_reason:
                self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}"
            else:
                self._pending_reload_reason = reason
            self._pending_force = self._pending_force or force
        logging.debug("config reload requested: %s (force=%s)", reason, force)

    def tick(self) -> None:
        """One poll cycle: detect file changes, then apply anything pending."""
        self._detect_file_change()
        self._apply_if_pending()

    def _run(self) -> None:
        # Event.wait doubles as the inter-poll sleep; returns True on stop().
        while not self._stop_event.wait(self.poll_interval_sec):
            self.tick()

    def _detect_file_change(self) -> None:
        current = self._fingerprint()
        if current == self._last_seen_fingerprint:
            return
        self._last_seen_fingerprint = current
        self.request_reload("file change detected", force=False)

    def _apply_if_pending(self) -> None:
        with self._request_lock:
            reason = self._pending_reload_reason
            force_reload = self._pending_force
        if not reason:
            return
        if self.daemon.get_state() != State.IDLE:
            # Never swap config mid-recording/processing; retry next tick.
            logging.debug("config reload deferred; daemon state=%s", self.daemon.get_state())
            return
        try:
            candidate = load(str(self.config_path))
        except Exception as exc:
            self._clear_pending()
            logging.error("config reload failed (%s): %s", reason, exc)
            self._notify("Config Reload Failed", str(exc), error=True)
            return
        status, changed, error = self.daemon.try_apply_config(candidate)
        if status == "deferred":
            logging.debug("config reload deferred during apply: %s", error)
            return
        self._clear_pending()
        if status == "error":
            logging.error("config reload failed (%s): %s", reason, error)
            self._notify("Config Reload Failed", error, error=True)
            return
        if changed:
            summary = ", ".join(changed)
            logging.info("config reloaded (%s): %s", reason, summary)
            self._notify("Config Reloaded", f"Applied: {summary}", error=False)
        elif force_reload:
            logging.info("config reloaded (%s): no effective changes", reason)
            self._notify("Config Reloaded", "No effective changes", error=False)

    def _clear_pending(self) -> None:
        with self._request_lock:
            self._pending_reload_reason = ""
            self._pending_force = False

    def _fingerprint(self) -> str | None:
        """SHA-256 of the config file, or None when it is missing/unreadable."""
        try:
            return hashlib.sha256(self.config_path.read_bytes()).hexdigest()
        except FileNotFoundError:
            return None
        except OSError as exc:
            logging.warning("config fingerprint failed (%s): %s", self.config_path, exc)
            return None

    def _notify(self, title: str, body: str, *, error: bool) -> None:
        if self.notifier is None:
            return
        if not self.notifier.send(title, body, error=error):
            logging.debug("desktop notifications unavailable: %s", title)
class PipelineReloader:
    """Background poller that hot-reloads the user's pipelines module.

    Mirrors ConfigReloader: watch the module file's fingerprint, and when
    it changes, rebuild the hotkey-to-pipeline bindings and hand them to
    the daemon while it is idle.
    """

    def __init__(
        self,
        *,
        daemon: Daemon,
        pipelines_path: Path,
        notifier: DesktopNotifier | None = None,
        dry_run: bool = False,
        poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC,
    ):
        self.daemon = daemon
        self.pipelines_path = pipelines_path
        self.notifier = notifier
        self.dry_run = dry_run
        self.poll_interval_sec = poll_interval_sec
        self._stop_event = threading.Event()
        self._request_lock = threading.Lock()
        self._pending_reload_reason = ""
        self._pending_force = False
        # Seed with the current fingerprint so startup does not re-trigger.
        self._last_seen_fingerprint = pipelines_fingerprint(self.pipelines_path)
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the poller thread (idempotent)."""
        if self._thread is not None:
            return
        worker = threading.Thread(target=self._run, daemon=True, name="pipeline-reloader")
        self._thread = worker
        worker.start()

    def stop(self, timeout: float = 2.0) -> None:
        """Signal the poller to stop and join it."""
        self._stop_event.set()
        worker = self._thread
        if worker is not None:
            worker.join(timeout=timeout)
            self._thread = None

    def request_reload(self, reason: str, *, force: bool = False) -> None:
        """Queue a reload; reasons accumulate until the reload is applied."""
        with self._request_lock:
            if self._pending_reload_reason:
                self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}"
            else:
                self._pending_reload_reason = reason
            self._pending_force = self._pending_force or force
        logging.debug("pipeline reload requested: %s (force=%s)", reason, force)

    def tick(self) -> None:
        """One poll cycle: detect file changes, then apply anything pending."""
        self._detect_file_change()
        self._apply_if_pending()

    def _run(self) -> None:
        # Event.wait doubles as the inter-poll sleep; returns True on stop().
        while not self._stop_event.wait(self.poll_interval_sec):
            self.tick()

    def _detect_file_change(self) -> None:
        current = pipelines_fingerprint(self.pipelines_path)
        if current == self._last_seen_fingerprint:
            return
        self._last_seen_fingerprint = current
        self.request_reload("file change detected", force=False)

    def _apply_if_pending(self) -> None:
        with self._request_lock:
            reason = self._pending_reload_reason
            force_reload = self._pending_force
        if not reason:
            return
        if self.daemon.get_state() != State.IDLE:
            logging.debug("pipeline reload deferred; daemon state=%s", self.daemon.get_state())
            return
        try:
            bindings = load_bindings(
                path=self.pipelines_path,
                default_hotkey=self.daemon.cfg.daemon.hotkey,
                default_handler_factory=self.daemon.build_reference_pipeline,
            )
            callbacks = self.daemon.build_hotkey_callbacks(
                list(bindings.keys()),
                dry_run=self.dry_run,
            )
        except Exception as exc:
            self._clear_pending()
            logging.error("pipeline reload failed (%s): %s", reason, exc)
            self._notify("Pipelines Reload Failed", str(exc), error=True)
            return
        status, error = self.daemon.apply_pipeline_bindings(bindings, callbacks)
        if status == "deferred":
            logging.debug("pipeline reload deferred during apply: %s", error)
            return
        self._clear_pending()
        if status == "error":
            logging.error("pipeline reload failed (%s): %s", reason, error)
            self._notify("Pipelines Reload Failed", error, error=True)
            return
        hotkeys = ", ".join(sorted(bindings.keys()))
        logging.info("pipelines reloaded (%s): %s", reason, hotkeys)
        self._notify("Pipelines Reloaded", f"Hotkeys: {hotkeys}", error=False)
        if force_reload and not bindings:
            # NOTE(review): bindings appear to be validated non-empty upstream,
            # so this branch looks unreachable — confirm before removing.
            logging.info("pipelines reloaded (%s): no hotkeys", reason)

    def _clear_pending(self) -> None:
        with self._request_lock:
            self._pending_reload_reason = ""
            self._pending_force = False

    def _notify(self, title: str, body: str, *, error: bool) -> None:
        if self.notifier is None:
            return
        if not self.notifier.send(title, body, error=error):
            logging.debug("desktop notifications unavailable: %s", title)
def _read_lock_pid(lock_file) -> str: def _read_lock_pid(lock_file) -> str:
@ -869,49 +366,19 @@ def main():
level=logging.DEBUG if args.verbose else logging.INFO, level=logging.DEBUG if args.verbose else logging.INFO,
format="aman: %(asctime)s %(levelname)s %(message)s", format="aman: %(asctime)s %(levelname)s %(message)s",
) )
config_path = Path(args.config) if args.config else DEFAULT_CONFIG_PATH cfg = load(args.config)
cfg = load(str(config_path))
_LOCK_HANDLE = _lock_single_instance() _LOCK_HANDLE = _lock_single_instance()
logging.info("hotkey: %s", cfg.daemon.hotkey) logging.info("hotkey: %s", cfg.daemon.hotkey)
logging.info( logging.info(
"config (%s):\n%s", "config (%s):\n%s",
str(config_path), args.config or str(Path.home() / ".config" / "aman" / "config.json"),
json.dumps(redacted_dict(cfg), indent=2), json.dumps(redacted_dict(cfg), indent=2),
) )
config_reloader = None
pipeline_reloader = None
try: try:
desktop = get_desktop_adapter() desktop = get_desktop_adapter()
daemon = Daemon(cfg, desktop, verbose=args.verbose) daemon = Daemon(cfg, desktop, verbose=args.verbose)
notifier = DesktopNotifier(app_name="aman")
config_reloader = ConfigReloader(
daemon=daemon,
config_path=config_path,
notifier=notifier,
poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC,
)
pipelines_path = DEFAULT_PIPELINES_PATH
initial_bindings = load_bindings(
path=pipelines_path,
default_hotkey=cfg.daemon.hotkey,
default_handler_factory=daemon.build_reference_pipeline,
)
initial_callbacks = daemon.build_hotkey_callbacks(
list(initial_bindings.keys()),
dry_run=args.dry_run,
)
status, error = daemon.apply_pipeline_bindings(initial_bindings, initial_callbacks)
if status != "applied":
raise RuntimeError(f"pipeline setup failed: {error}")
pipeline_reloader = PipelineReloader(
daemon=daemon,
pipelines_path=pipelines_path,
notifier=notifier,
dry_run=args.dry_run,
poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC,
)
except Exception as exc: except Exception as exc:
logging.error("startup failed: %s", exc) logging.error("startup failed: %s", exc)
raise SystemExit(1) raise SystemExit(1)
@ -923,10 +390,6 @@ def main():
return return
shutdown_once.set() shutdown_once.set()
logging.info("%s, shutting down", reason) logging.info("%s, shutting down", reason)
if config_reloader is not None:
config_reloader.stop(timeout=2.0)
if pipeline_reloader is not None:
pipeline_reloader.stop(timeout=2.0)
if not daemon.shutdown(timeout=5.0): if not daemon.shutdown(timeout=5.0):
logging.warning("timed out waiting for idle state during shutdown") logging.warning("timed out waiting for idle state during shutdown")
desktop.request_quit() desktop.request_quit()
@ -934,30 +397,15 @@ def main():
def handle_signal(_sig, _frame): def handle_signal(_sig, _frame):
threading.Thread(target=shutdown, args=("signal received",), daemon=True).start() threading.Thread(target=shutdown, args=("signal received",), daemon=True).start()
def handle_reload_signal(_sig, _frame):
if shutdown_once.is_set():
return
if config_reloader is not None:
threading.Thread(
target=lambda: config_reloader.request_reload("signal SIGHUP", force=True),
daemon=True,
).start()
if pipeline_reloader is not None:
threading.Thread(
target=lambda: pipeline_reloader.request_reload("signal SIGHUP", force=True),
daemon=True,
).start()
signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal) signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGHUP, handle_reload_signal)
try: try:
desktop.start_hotkey_listener(
cfg.daemon.hotkey,
lambda: logging.info("hotkey pressed (dry-run)") if args.dry_run else daemon.toggle(),
)
desktop.start_cancel_listener(lambda: daemon.cancel_recording()) desktop.start_cancel_listener(lambda: daemon.cancel_recording())
if config_reloader is not None:
config_reloader.start()
if pipeline_reloader is not None:
pipeline_reloader.start()
except Exception as exc: except Exception as exc:
logging.error("hotkey setup failed: %s", exc) logging.error("hotkey setup failed: %s", exc)
raise SystemExit(1) raise SystemExit(1)
@ -965,10 +413,6 @@ def main():
try: try:
desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested")) desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested"))
finally: finally:
if config_reloader is not None:
config_reloader.stop(timeout=2.0)
if pipeline_reloader is not None:
pipeline_reloader.stop(timeout=2.0)
daemon.shutdown(timeout=1.0) daemon.shutdown(timeout=1.0)

View file

@ -51,6 +51,11 @@ class VocabularyConfig:
terms: list[str] = field(default_factory=list) terms: list[str] = field(default_factory=list)
@dataclass
class DomainInferenceConfig:
enabled: bool = True
@dataclass @dataclass
class Config: class Config:
daemon: DaemonConfig = field(default_factory=DaemonConfig) daemon: DaemonConfig = field(default_factory=DaemonConfig)
@ -58,6 +63,7 @@ class Config:
stt: SttConfig = field(default_factory=SttConfig) stt: SttConfig = field(default_factory=SttConfig)
injection: InjectionConfig = field(default_factory=InjectionConfig) injection: InjectionConfig = field(default_factory=InjectionConfig)
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig) vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig)
def load(path: str | None) -> Config: def load(path: str | None) -> Config:
@ -118,8 +124,20 @@ def validate(cfg: Config) -> None:
cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements) cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements)
cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms) cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms)
if not isinstance(cfg.domain_inference.enabled, bool):
raise ValueError("domain_inference.enabled must be boolean")
def _from_dict(data: dict[str, Any], cfg: Config) -> Config: def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
if "logging" in data:
raise ValueError("logging section is no longer supported; use -v/--verbose")
if "log_transcript" in data:
raise ValueError("log_transcript is no longer supported; use -v/--verbose")
if "ai" in data:
raise ValueError("ai section is no longer supported")
if "ai_enabled" in data:
raise ValueError("ai_enabled is no longer supported")
has_sections = any( has_sections = any(
key in data key in data
for key in ( for key in (
@ -128,6 +146,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
"stt", "stt",
"injection", "injection",
"vocabulary", "vocabulary",
"domain_inference",
) )
) )
if has_sections: if has_sections:
@ -136,6 +155,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
stt = _ensure_dict(data.get("stt"), "stt") stt = _ensure_dict(data.get("stt"), "stt")
injection = _ensure_dict(data.get("injection"), "injection") injection = _ensure_dict(data.get("injection"), "injection")
vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary") vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary")
domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference")
if "hotkey" in daemon: if "hotkey" in daemon:
cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey") cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey")
@ -156,7 +176,28 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config:
cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"]) cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"])
if "terms" in vocabulary: if "terms" in vocabulary:
cfg.vocabulary.terms = _as_terms(vocabulary["terms"]) cfg.vocabulary.terms = _as_terms(vocabulary["terms"])
if "max_rules" in vocabulary:
raise ValueError("vocabulary.max_rules is no longer supported")
if "max_terms" in vocabulary:
raise ValueError("vocabulary.max_terms is no longer supported")
if "enabled" in domain_inference:
cfg.domain_inference.enabled = _as_bool(
domain_inference["enabled"], "domain_inference.enabled"
)
if "mode" in domain_inference:
raise ValueError("domain_inference.mode is no longer supported")
return cfg return cfg
if "hotkey" in data:
cfg.daemon.hotkey = _as_nonempty_str(data["hotkey"], "hotkey")
if "input" in data:
cfg.recording.input = _as_recording_input(data["input"])
if "whisper_model" in data:
cfg.stt.model = _as_nonempty_str(data["whisper_model"], "whisper_model")
if "whisper_device" in data:
cfg.stt.device = _as_nonempty_str(data["whisper_device"], "whisper_device")
if "injection_backend" in data:
cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend")
return cfg return cfg

View file

@ -2,7 +2,6 @@ from pathlib import Path
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json" DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json"
CONFIG_RELOAD_POLL_INTERVAL_SEC = 1.0
RECORD_TIMEOUT_SEC = 300 RECORD_TIMEOUT_SEC = 300
STT_LANGUAGE = "en" STT_LANGUAGE = "en"
TRAY_UPDATE_MS = 250 TRAY_UPDATE_MS = 250

View file

@ -5,7 +5,7 @@ from typing import Callable, Protocol
class DesktopAdapter(Protocol): class DesktopAdapter(Protocol):
def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None:
raise NotImplementedError raise NotImplementedError
def start_cancel_listener(self, callback: Callable[[], None]) -> None: def start_cancel_listener(self, callback: Callable[[], None]) -> None:

View file

@ -4,7 +4,7 @@ from typing import Callable
class WaylandAdapter: class WaylandAdapter:
def set_hotkeys(self, _bindings: dict[str, Callable[[], None]]) -> None: def start_hotkey_listener(self, _hotkey: str, _callback: Callable[[], None]) -> None:
raise SystemExit("Wayland hotkeys are not supported yet.") raise SystemExit("Wayland hotkeys are not supported yet.")
def start_cancel_listener(self, _callback: Callable[[], None]) -> None: def start_cancel_listener(self, _callback: Callable[[], None]) -> None:

View file

@ -4,7 +4,7 @@ import logging
import threading import threading
import time import time
import warnings import warnings
from typing import Any, Callable, Iterable from typing import Callable, Iterable
import gi import gi
from Xlib import X, XK, display from Xlib import X, XK, display
@ -43,8 +43,6 @@ class X11Adapter:
self.indicator = None self.indicator = None
self.status_icon = None self.status_icon = None
self.menu = None self.menu = None
self._hotkey_listener_lock = threading.Lock()
self._hotkey_listeners: dict[str, dict[str, Any]] = {}
if AppIndicator3 is not None: if AppIndicator3 is not None:
self.indicator = AppIndicator3.Indicator.new( self.indicator = AppIndicator3.Indicator.new(
"aman", "aman",
@ -67,29 +65,15 @@ class X11Adapter:
if self.menu: if self.menu:
self.menu.popup(None, None, None, None, 0, _time) self.menu.popup(None, None, None, None, 0, _time)
def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None:
if not isinstance(bindings, dict): mods, keysym = self._parse_hotkey(hotkey)
raise ValueError("bindings must be a dictionary") self._validate_hotkey_registration(mods, keysym)
next_listeners: dict[str, dict[str, Any]] = {} thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True)
try: thread.start()
for hotkey, callback in bindings.items():
if not callable(callback):
raise ValueError(f"callback for hotkey {hotkey} must be callable")
next_listeners[hotkey] = self._spawn_hotkey_listener(hotkey, callback)
except Exception:
for listener in next_listeners.values():
self._stop_hotkey_listener(listener)
raise
with self._hotkey_listener_lock:
previous = self._hotkey_listeners
self._hotkey_listeners = next_listeners
for listener in previous.values():
self._stop_hotkey_listener(listener)
def start_cancel_listener(self, callback: Callable[[], None]) -> None: def start_cancel_listener(self, callback: Callable[[], None]) -> None:
mods, keysym = self._parse_hotkey("Escape") mods, keysym = self._parse_hotkey("Escape")
thread = threading.Thread(target=self._listen, args=(mods, keysym, callback, threading.Event()), daemon=True) thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True)
thread.start() thread.start()
def inject_text( def inject_text(
@ -143,14 +127,7 @@ class X11Adapter:
finally: finally:
self.request_quit() self.request_quit()
def _listen( def _listen(self, mods: int, keysym: int, callback: Callable[[], None]) -> None:
self,
mods: int,
keysym: int,
callback: Callable[[], None],
stop_event: threading.Event,
listener_meta: dict[str, Any] | None = None,
) -> None:
disp = None disp = None
root = None root = None
keycode = None keycode = None
@ -158,26 +135,14 @@ class X11Adapter:
disp = display.Display() disp = display.Display()
root = disp.screen().root root = disp.screen().root
keycode = self._grab_hotkey(disp, root, mods, keysym) keycode = self._grab_hotkey(disp, root, mods, keysym)
if listener_meta is not None: while True:
listener_meta["display"] = disp
listener_meta["root"] = root
listener_meta["keycode"] = keycode
listener_meta["ready"].set()
while not stop_event.is_set():
if disp.pending_events() == 0:
time.sleep(0.05)
continue
ev = disp.next_event() ev = disp.next_event()
if ev.type == X.KeyPress and ev.detail == keycode: if ev.type == X.KeyPress and ev.detail == keycode:
state = ev.state & ~(X.LockMask | X.Mod2Mask) state = ev.state & ~(X.LockMask | X.Mod2Mask)
if state == mods: if state == mods:
callback() callback()
except Exception as exc: except Exception as exc:
if listener_meta is not None: logging.error("hotkey listener stopped: %s", exc)
listener_meta["error"] = exc
listener_meta["ready"].set()
if not stop_event.is_set():
logging.error("hotkey listener stopped: %s", exc)
finally: finally:
if root is not None and keycode is not None and disp is not None: if root is not None and keycode is not None and disp is not None:
try: try:
@ -185,13 +150,6 @@ class X11Adapter:
disp.sync() disp.sync()
except Exception: except Exception:
pass pass
if disp is not None:
try:
disp.close()
except Exception:
pass
if listener_meta is not None:
listener_meta["ready"].set()
def _parse_hotkey(self, hotkey: str): def _parse_hotkey(self, hotkey: str):
mods = 0 mods = 0
@ -227,51 +185,6 @@ class X11Adapter:
except Exception: except Exception:
pass pass
def _spawn_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> dict[str, Any]:
mods, keysym = self._parse_hotkey(hotkey)
self._validate_hotkey_registration(mods, keysym)
stop_event = threading.Event()
listener_meta: dict[str, Any] = {
"hotkey": hotkey,
"mods": mods,
"keysym": keysym,
"stop_event": stop_event,
"display": None,
"root": None,
"keycode": None,
"error": None,
"ready": threading.Event(),
}
thread = threading.Thread(
target=self._listen,
args=(mods, keysym, callback, stop_event, listener_meta),
daemon=True,
)
listener_meta["thread"] = thread
thread.start()
if not listener_meta["ready"].wait(timeout=2.0):
stop_event.set()
raise RuntimeError("hotkey listener setup timed out")
if listener_meta["error"] is not None:
stop_event.set()
raise listener_meta["error"]
if listener_meta["keycode"] is None:
stop_event.set()
raise RuntimeError("hotkey listener failed to initialize")
return listener_meta
def _stop_hotkey_listener(self, listener: dict[str, Any]) -> None:
listener["stop_event"].set()
disp = listener.get("display")
if disp is not None:
try:
disp.close()
except Exception:
pass
thread = listener.get("thread")
if thread is not None:
thread.join(timeout=1.0)
def _grab_hotkey(self, disp, root, mods, keysym): def _grab_hotkey(self, disp, root, mods, keysym):
keycode = disp.keysym_to_keycode(keysym) keycode = disp.keysym_to_keycode(keysym)
if keycode == 0: if keycode == 0:

View file

@ -1,86 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any, Callable
# Valid values for PipelineOptions.failure_policy.
ALLOWED_FAILURE_POLICIES = {"best_effort", "strict"}


@dataclass(frozen=True)
class PipelineOptions:
    """Per-pipeline execution options; normalizes and validates the policy."""

    failure_policy: str = "best_effort"

    def __post_init__(self):
        # Accept any casing/surrounding whitespace, store the canonical form.
        normalized = (self.failure_policy or "").strip().lower()
        if normalized in ALLOWED_FAILURE_POLICIES:
            object.__setattr__(self, "failure_policy", normalized)
            return
        raise ValueError(
            "failure_policy must be one of: " + ", ".join(sorted(ALLOWED_FAILURE_POLICIES))
        )
@dataclass(frozen=True)
class PipelineBinding:
    """One hotkey's pipeline: the key, the user handler, and its options."""

    # Hotkey string this pipeline is bound to (e.g. "Super+m").
    hotkey: str
    # User handler invoked as handler(audio, lib); expected to return text.
    handler: Callable[[Any, "PipelineLib"], str]
    # Execution options (currently just the failure policy).
    options: PipelineOptions
class PipelineLib:
    """Facade handed to user pipelines: transcription and LLM access."""

    def __init__(
        self,
        *,
        transcribe_fn: Callable[[Any], str],
        llm_fn: Callable[..., str],
    ):
        self._transcribe_fn = transcribe_fn
        self._llm_fn = llm_fn

    def transcribe(
        self,
        audio: Any,
        *,
        hints: list[str] | None = None,
        whisper_opts: dict[str, Any] | None = None,
    ) -> str:
        """Run speech-to-text on *audio*, forwarding optional hints/options."""
        return self._transcribe_fn(audio, hints=hints, whisper_opts=whisper_opts)

    def llm(
        self,
        *,
        system_prompt: str,
        user_prompt: str,
        llm_opts: dict[str, Any] | None = None,
    ) -> str:
        """Run one LLM chat completion with the given prompts and options."""
        return self._llm_fn(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            llm_opts=llm_opts,
        )
class Engine:
    """Runs one pipeline binding against captured audio."""

    def __init__(self, lib: PipelineLib):
        self.lib = lib

    def run(self, binding: PipelineBinding, audio: Any) -> str:
        """Invoke the binding's handler; policy decides how failures surface.

        best_effort: log the handler exception and return "".
        strict: re-raise the handler exception.
        None output maps to ""; any other non-string output is an error.
        """
        try:
            result = binding.handler(audio, self.lib)
        except Exception as exc:
            if binding.options.failure_policy == "strict":
                raise
            logging.error("pipeline failed for hotkey %s: %s", binding.hotkey, exc)
            return ""
        if result is None:
            return ""
        if isinstance(result, str):
            return result
        # Raised outside the try, so this propagates regardless of policy.
        raise RuntimeError(
            f"pipeline for hotkey {binding.hotkey} returned non-string output"
        )

View file

@ -1,40 +0,0 @@
from __future__ import annotations
class DesktopNotifier:
    """Best-effort libnotify wrapper; degrades to a no-op when unavailable."""

    def __init__(self, app_name: str = "aman"):
        self._app_name = app_name
        self._notify = None
        self._ready = False
        # Backend discovery is attempted at most once per instance.
        self._attempted = False

    def send(self, title: str, body: str, *, error: bool = False) -> bool:
        """Show a notification; return False when no backend or show fails."""
        _ = error  # severity is currently unused by the backend
        backend = self._ensure_backend()
        if backend is None:
            return False
        try:
            backend.Notification.new(title, body, None).show()
        except Exception:
            return False
        return True

    def _ensure_backend(self):
        # Lazily import and initialize libnotify via GObject introspection;
        # any failure leaves the notifier permanently disabled.
        if not self._attempted:
            self._attempted = True
            try:
                import gi

                gi.require_version("Notify", "0.7")
                from gi.repository import Notify  # type: ignore[import-not-found]

                if Notify.init(self._app_name):
                    self._notify = Notify
                    self._ready = True
            except Exception:
                self._notify = None
                self._ready = False
        return self._notify if self._ready else None

View file

@ -1,97 +0,0 @@
from __future__ import annotations
import hashlib
import importlib.util
from pathlib import Path
from types import ModuleType
from typing import Any, Callable
from engine import PipelineBinding, PipelineOptions
DEFAULT_PIPELINES_PATH = Path.home() / ".config" / "aman" / "pipelines.py"
def fingerprint(path: Path) -> str | None:
    """Return the SHA-256 hex digest of *path*'s bytes, or None if unreadable.

    Bug fix: previously only FileNotFoundError was caught, so any other I/O
    failure (path is a directory, permission denied, transient replace during
    an editor save) propagated and could kill the reloader's polling thread.
    All OSErrors now map to None, matching the "no fingerprint" semantics.
    """
    try:
        data = path.read_bytes()
    except OSError:  # FileNotFoundError, IsADirectoryError, PermissionError, ...
        return None
    return hashlib.sha256(data).hexdigest()
def load_bindings(
    *,
    path: Path,
    default_hotkey: str,
    default_handler_factory: Callable[[], Callable[[Any, Any], str]],
) -> dict[str, PipelineBinding]:
    """Load hotkey -> pipeline bindings from the user pipelines file.

    When *path* does not exist, fall back to a single best-effort binding on
    *default_hotkey* whose handler comes from *default_handler_factory*.
    Otherwise execute the file as a module and build bindings from its
    HOTKEY_PIPELINES / PIPELINE_OPTIONS declarations.
    """
    if not path.exists():
        handler = default_handler_factory()
        options = PipelineOptions(failure_policy="best_effort")
        return {
            default_hotkey: PipelineBinding(
                hotkey=default_hotkey,
                handler=handler,
                options=options,
            )
        }
    module = _load_module(path)
    return _bindings_from_module(module)
def _load_module(path: Path) -> ModuleType:
    """Execute *path* as a standalone module and return the module object."""
    # Unique per resolved path so repeated loads don't collide in sys.modules.
    unique_name = f"aman_user_pipelines_{hash(path.resolve())}"
    spec = importlib.util.spec_from_file_location(unique_name, str(path))
    loader = spec.loader if spec is not None else None
    if loader is None:
        raise ValueError(f"unable to load pipelines module: {path}")
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
def _bindings_from_module(module: ModuleType) -> dict[str, PipelineBinding]:
    """Build validated hotkey bindings from a user pipelines module.

    The module must define a non-empty HOTKEY_PIPELINES dict mapping hotkey
    strings to callables; PIPELINE_OPTIONS is an optional per-hotkey dict of
    option objects. Raises ValueError on any structural problem.
    """
    raw_pipelines = getattr(module, "HOTKEY_PIPELINES", None)
    if not isinstance(raw_pipelines, dict):
        raise ValueError("HOTKEY_PIPELINES must be a dictionary")
    if not raw_pipelines:
        raise ValueError("HOTKEY_PIPELINES cannot be empty")
    raw_options = getattr(module, "PIPELINE_OPTIONS", {})
    if raw_options is None:
        # Treat an explicit PIPELINE_OPTIONS = None like an absent mapping.
        raw_options = {}
    if not isinstance(raw_options, dict):
        raise ValueError("PIPELINE_OPTIONS must be a dictionary when provided")
    bindings: dict[str, PipelineBinding] = {}
    for key, handler in raw_pipelines.items():
        hotkey = _validate_hotkey(key)
        if not callable(handler):
            raise ValueError(f"HOTKEY_PIPELINES[{hotkey}] must be callable")
        option_data = raw_options.get(hotkey, {})
        options = _parse_options(hotkey, option_data)
        bindings[hotkey] = PipelineBinding(
            hotkey=hotkey,
            handler=handler,
            options=options,
        )
    return bindings
def _validate_hotkey(value: Any) -> str:
    """Return *value* stripped, rejecting non-string or blank hotkey keys."""
    if isinstance(value, str):
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("pipeline hotkey keys cannot be empty")
    raise ValueError("pipeline hotkey keys must be strings")
def _parse_options(hotkey: str, value: Any) -> PipelineOptions:
    """Coerce one PIPELINE_OPTIONS entry for *hotkey* into PipelineOptions.

    None means "use defaults"; anything other than a dict is rejected.
    PipelineOptions itself validates the failure_policy value.
    """
    if value is None:
        return PipelineOptions()
    if not isinstance(value, dict):
        raise ValueError(f"PIPELINE_OPTIONS[{hotkey}] must be an object")
    failure_policy = value.get("failure_policy", "best_effort")
    return PipelineOptions(failure_policy=failure_policy)

View file

@ -4,7 +4,101 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable from typing import Iterable
from config import VocabularyConfig from config import DomainInferenceConfig, VocabularyConfig
# Advisory domain labels produced by VocabularyEngine.infer_domain().
DOMAIN_GENERAL = "general"
DOMAIN_PERSONAL_NAMES = "personal_names"
DOMAIN_SOFTWARE_DEV = "software_dev"
DOMAIN_OPS_INFRA = "ops_infra"
DOMAIN_BUSINESS = "business"
DOMAIN_MEDICAL_LEGAL = "medical_legal"

# Fixed scoring/iteration order; general is the fallback, not scored here.
DOMAIN_ORDER = (
    DOMAIN_PERSONAL_NAMES,
    DOMAIN_SOFTWARE_DEV,
    DOMAIN_OPS_INFRA,
    DOMAIN_BUSINESS,
    DOMAIN_MEDICAL_LEGAL,
)

# Single-token evidence per domain; each hit scores 2 in infer_domain().
# Note: personal_names has no keyword set — it is boosted only by greetings.
DOMAIN_KEYWORDS = {
    DOMAIN_SOFTWARE_DEV: {
        "api",
        "bug",
        "code",
        "commit",
        "docker",
        "function",
        "git",
        "github",
        "javascript",
        "python",
        "refactor",
        "repository",
        "typescript",
        "unit",
        "test",
    },
    DOMAIN_OPS_INFRA: {
        "cluster",
        "container",
        "deploy",
        "deployment",
        "incident",
        "kubernetes",
        "monitoring",
        "nginx",
        "pod",
        "prod",
        "service",
        "systemd",
        "terraform",
    },
    DOMAIN_BUSINESS: {
        "budget",
        "client",
        "deadline",
        "finance",
        "invoice",
        "meeting",
        "milestone",
        "project",
        "quarter",
        "roadmap",
        "sales",
        "stakeholder",
    },
    DOMAIN_MEDICAL_LEGAL: {
        "agreement",
        "case",
        "claim",
        "compliance",
        "contract",
        "diagnosis",
        "liability",
        "patient",
        "prescription",
        "regulation",
        "symptom",
        "treatment",
    },
}

# Multi-word evidence matched as substrings of the casefolded text (scores 2).
DOMAIN_PHRASES = {
    DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"),
    DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"),
    DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"),
    DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"),
}

# NOTE(review): infer_domain checks these against single tokens, so the
# multi-word entries ("good morning", ...) can never match — confirm intent.
GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}


@dataclass(frozen=True)
class DomainResult:
    """Outcome of domain inference for one transcript."""

    # Domain label (one of the DOMAIN_* constants).
    name: str
    # 0.0 when falling back to general; otherwise the top domain's share
    # of total evidence, rounded to two decimals.
    confidence: float
@dataclass(frozen=True) @dataclass(frozen=True)
@ -14,9 +108,10 @@ class _ReplacementView:
class VocabularyEngine: class VocabularyEngine:
def __init__(self, vocab_cfg: VocabularyConfig): def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig):
self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements] self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements]
self._terms = list(vocab_cfg.terms) self._terms = list(vocab_cfg.terms)
self._domain_enabled = bool(domain_cfg.enabled)
self._replacement_map = { self._replacement_map = {
_normalize_key(rule.source): rule.target for rule in self._replacements _normalize_key(rule.source): rule.target for rule in self._replacements
@ -66,6 +161,55 @@ class VocabularyEngine:
used += addition used += addition
return "\n".join(out) return "\n".join(out)
def infer_domain(self, text: str) -> DomainResult:
    """Heuristically classify *text* into a domain with a confidence score.

    Advisory only: returns the general domain with 0.0 confidence when
    inference is disabled, the text has no tokens, or the keyword evidence
    is too weak (top score < 2 or confidence < 0.45).
    """
    if not self._domain_enabled:
        return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
    normalized = text.casefold()
    tokens = re.findall(r"[a-z0-9+#./_-]+", normalized)
    if not tokens:
        return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
    scores = {domain: 0 for domain in DOMAIN_ORDER}
    # Keyword and phrase hits weigh 2; greeting/dictionary boosts weigh 1.
    for token in tokens:
        for domain, keywords in DOMAIN_KEYWORDS.items():
            if token in keywords:
                scores[domain] += 2
    for domain, phrases in DOMAIN_PHRASES.items():
        for phrase in phrases:
            if phrase in normalized:
                scores[domain] += 2
    # NOTE(review): multi-word GREETING_TOKENS entries can never equal a
    # single token here — confirm whether substring matching was intended.
    if any(token in GREETING_TOKENS for token in tokens):
        scores[DOMAIN_PERSONAL_NAMES] += 1
    # Boost domains from configured dictionary terms and replacement targets.
    dictionary_tokens = self._dictionary_tokens()
    for token in dictionary_tokens:
        for domain, keywords in DOMAIN_KEYWORDS.items():
            if token in keywords and token in tokens:
                scores[domain] += 1
    # Pick the highest-scoring domain; ties keep the earlier DOMAIN_ORDER entry.
    top_domain = DOMAIN_GENERAL
    top_score = 0
    total_score = 0
    for domain in DOMAIN_ORDER:
        score = scores[domain]
        total_score += score
        if score > top_score:
            top_score = score
            top_domain = domain
    if top_score < 2 or total_score == 0:
        return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
    confidence = top_score / total_score
    if confidence < 0.45:
        return DomainResult(name=DOMAIN_GENERAL, confidence=0.0)
    return DomainResult(name=top_domain, confidence=round(confidence, 2))
def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str: def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str:
items = _dedupe_preserve_order( items = _dedupe_preserve_order(
[rule.target for rule in self._replacements] + self._terms [rule.target for rule in self._replacements] + self._terms
@ -92,6 +236,19 @@ class VocabularyEngine:
return "" return ""
return prefix + hotwords return prefix + hotwords
def _dictionary_tokens(self) -> set[str]:
    """Tokenize all configured replacement sources/targets and terms."""
    values: list[str] = []
    for rule in self._replacements:
        values.append(rule.source)
        values.append(rule.target)
    values.extend(self._terms)
    tokens: set[str] = set()
    for value in values:
        # Same token pattern as infer_domain() so boosts line up exactly.
        for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()):
            tokens.add(token)
    return tokens
def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None: def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None:
unique_sources = _dedupe_preserve_order(list(sources)) unique_sources = _dedupe_preserve_order(list(sources))

View file

@ -1,4 +1,3 @@
import json
import os import os
import sys import sys
import tempfile import tempfile
@ -12,25 +11,14 @@ if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC)) sys.path.insert(0, str(SRC))
import aman import aman
from config import Config, VocabularyReplacement, redacted_dict from config import Config, VocabularyReplacement
from engine import PipelineBinding, PipelineOptions
class FakeDesktop: class FakeDesktop:
def __init__(self): def __init__(self):
self.inject_calls = [] self.inject_calls = []
self.hotkey_updates = []
self.hotkeys = {}
self.cancel_callback = None
self.quit_calls = 0 self.quit_calls = 0
def set_hotkeys(self, bindings):
self.hotkeys = dict(bindings)
self.hotkey_updates.append(tuple(sorted(bindings.keys())))
def start_cancel_listener(self, callback):
self.cancel_callback = callback
def inject_text( def inject_text(
self, self,
text: str, text: str,
@ -88,30 +76,12 @@ class FakeAIProcessor:
def process(self, text, lang="en", **_kwargs): def process(self, text, lang="en", **_kwargs):
return text return text
def chat(self, *, system_prompt, user_prompt, llm_opts=None):
    """Fake LLM chat: echo the transcript for JSON requests, else a domain."""
    _ = system_prompt
    opts = llm_opts or {}
    if "response_format" in opts:
        # Structured-output path: pretend to "clean" by echoing the transcript.
        payload = json.loads(user_prompt)
        transcript = payload.get("transcript", "")
        return json.dumps({"cleaned_text": transcript})
    return "general"
class FakeAudio: class FakeAudio:
def __init__(self, size: int): def __init__(self, size: int):
self.size = size self.size = size
class FakeNotifier:
    """Test double that records notifications instead of displaying them."""

    def __init__(self):
        self.events = []

    def send(self, title, body, *, error=False):
        event = (title, body, error)
        self.events.append(event)
        return True
class DaemonTests(unittest.TestCase): class DaemonTests(unittest.TestCase):
def _config(self) -> Config: def _config(self) -> Config:
cfg = Config() cfg = Config()
@ -133,23 +103,19 @@ class DaemonTests(unittest.TestCase):
@patch("aman.stop_audio_recording", return_value=FakeAudio(8)) @patch("aman.stop_audio_recording", return_value=FakeAudio(8))
@patch("aman.start_audio_recording", return_value=(object(), object())) @patch("aman.start_audio_recording", return_value=(object(), object()))
def test_hotkey_start_stop_injects_text(self, _start_mock, _stop_mock): def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
desktop = FakeDesktop() desktop = FakeDesktop()
daemon = self._build_daemon(desktop, FakeModel(), verbose=False) daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
daemon._start_stop_worker = ( daemon._start_stop_worker = (
lambda stream, record, trigger, process_audio: daemon._stop_and_process( lambda stream, record, trigger, process_audio: daemon._stop_and_process(
stream, stream, record, trigger, process_audio
record,
trigger,
process_audio,
daemon._pending_worker_hotkey,
) )
) )
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
self.assertEqual(daemon.get_state(), aman.State.RECORDING) self.assertEqual(daemon.get_state(), aman.State.RECORDING)
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
self.assertEqual(daemon.get_state(), aman.State.IDLE) self.assertEqual(daemon.get_state(), aman.State.IDLE)
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)]) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)])
@ -165,7 +131,7 @@ class DaemonTests(unittest.TestCase):
) )
) )
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
self.assertEqual(daemon.get_state(), aman.State.RECORDING) self.assertEqual(daemon.get_state(), aman.State.RECORDING)
self.assertTrue(daemon.shutdown(timeout=0.2)) self.assertTrue(daemon.shutdown(timeout=0.2))
@ -187,8 +153,8 @@ class DaemonTests(unittest.TestCase):
) )
) )
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)]) self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
@ -257,8 +223,8 @@ class DaemonTests(unittest.TestCase):
) )
) )
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
daemon.handle_hotkey(daemon.cfg.daemon.hotkey) daemon.toggle()
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)]) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)])
@ -273,161 +239,6 @@ class DaemonTests(unittest.TestCase):
any("DEBUG:root:state: idle -> recording" in line for line in logs.output) any("DEBUG:root:state: idle -> recording" in line for line in logs.output)
) )
def test_hotkey_dispatches_to_matching_pipeline(self):
    """A hotkey event runs the pipeline registered for that exact hotkey."""
    desktop = FakeDesktop()
    daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
    bindings = {}
    for key, outcome in (("Super+m", "default"), ("Super+Shift+m", "caps")):
        # Bind outcome as a default arg so each lambda keeps its own value.
        bindings[key] = PipelineBinding(
            hotkey=key,
            handler=lambda _audio, _lib, _outcome=outcome: _outcome,
            options=PipelineOptions(failure_policy="best_effort"),
        )
    daemon.set_pipeline_bindings(bindings)
    self.assertEqual(daemon._run_pipeline(object(), "Super+Shift+m"), "caps")
def test_try_apply_config_applies_new_runtime_settings(self):
    """Applying a changed config swaps runtime settings and drops stale STT-hint cache."""
    desktop = FakeDesktop()
    daemon = self._build_daemon(desktop, FakeModel(), cfg=self._config(), verbose=False)
    daemon._stt_hint_kwargs_cache = {"hotwords": "old"}

    new_cfg = Config()
    new_cfg.injection.backend = "injection"
    new_cfg.vocabulary.replacements = [
        VocabularyReplacement(source="Martha", target="Marta"),
    ]

    status, changed, error = daemon.try_apply_config(new_cfg)

    self.assertEqual((status, error), ("applied", ""))
    self.assertIn("injection.backend", changed)
    self.assertIn("vocabulary.replacements", changed)
    self.assertEqual(daemon.cfg.injection.backend, "injection")
    self.assertEqual(
        daemon.vocabulary.apply_deterministic_replacements("Martha"), "Marta"
    )
    # Hint cache must be invalidated so new vocabulary reaches the STT layer.
    self.assertIsNone(daemon._stt_hint_kwargs_cache)
def test_try_apply_config_reloads_stt_model(self):
    """Changing stt.model rebuilds the Whisper model via _build_whisper_model."""
    desktop = FakeDesktop()
    daemon = self._build_daemon(desktop, FakeModel(), cfg=self._config(), verbose=False)

    new_cfg = Config()
    new_cfg.stt.model = "small"
    replacement_model = FakeModel(text="from-new-model")

    with patch("aman._build_whisper_model", return_value=replacement_model):
        status, changed, error = daemon.try_apply_config(new_cfg)

    self.assertEqual((status, error), ("applied", ""))
    self.assertIn("stt.model", changed)
    # The daemon must now hold the freshly built model instance.
    self.assertIs(daemon.model, replacement_model)
def test_try_apply_config_is_deferred_while_busy(self):
    """A config apply while recording is deferred, not applied or rejected."""
    daemon = self._build_daemon(FakeDesktop(), FakeModel(), verbose=False)
    daemon.set_state(aman.State.RECORDING)

    status, changed, error = daemon.try_apply_config(Config())

    self.assertEqual(status, "deferred")
    self.assertEqual(changed, [])
    self.assertIn("busy", error)
class ConfigReloaderTests(unittest.TestCase):
    """Polling-based config reload: apply, keep-last-good, and defer-while-busy."""

    def _write_config(self, path: Path, cfg: Config):
        """Serialize *cfg* as JSON at *path*, creating parent directories."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(redacted_dict(cfg)), encoding="utf-8")

    def _daemon_for_path(self, path: Path) -> tuple[aman.Daemon, FakeDesktop]:
        """Build a daemon whose config file lives at *path*, with model/AI fakes."""
        cfg = Config()
        self._write_config(path, cfg)
        desktop = FakeDesktop()
        with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
            "aman.LlamaProcessor", return_value=FakeAIProcessor()
        ):
            daemon = aman.Daemon(cfg, desktop, verbose=False)
        return daemon, desktop

    def _make_reloader(self, daemon, path, notifier):
        """Fast-polling ConfigReloader wired to *daemon* and *notifier*."""
        return aman.ConfigReloader(
            daemon=daemon,
            config_path=path,
            notifier=notifier,
            poll_interval_sec=0.01,
        )

    def test_reloader_applies_changed_config(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "config.json"
            daemon, _desktop = self._daemon_for_path(path)
            notifier = FakeNotifier()
            reloader = self._make_reloader(daemon, path, notifier)

            updated = Config()
            updated.injection.backend = "injection"
            self._write_config(path, updated)
            reloader.tick()

            self.assertEqual(daemon.cfg.injection.backend, "injection")
            self.assertTrue(
                any(evt[0] == "Config Reloaded" and not evt[2] for evt in notifier.events)
            )

    def test_reloader_keeps_last_good_config_when_invalid(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "config.json"
            daemon, _desktop = self._daemon_for_path(path)
            notifier = FakeNotifier()
            reloader = self._make_reloader(daemon, path, notifier)

            path.write_text('{"injection":{"backend":"invalid"}}', encoding="utf-8")
            reloader.tick()

            # Last known-good config stays active.
            self.assertEqual(daemon.cfg.injection.backend, "clipboard")
            failures = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
            self.assertEqual(len(failures), 1)

            # A second tick on the same broken file must not re-notify.
            reloader.tick()
            failures = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
            self.assertEqual(len(failures), 1)

    def test_reloader_defers_apply_until_idle(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "config.json"
            daemon, _desktop = self._daemon_for_path(path)
            notifier = FakeNotifier()
            reloader = self._make_reloader(daemon, path, notifier)

            updated = Config()
            updated.injection.backend = "injection"
            self._write_config(path, updated)

            daemon.set_state(aman.State.RECORDING)
            reloader.tick()
            # While busy, the old config remains in force.
            self.assertEqual(daemon.cfg.injection.backend, "clipboard")

            daemon.set_state(aman.State.IDLE)
            reloader.tick()
            self.assertEqual(daemon.cfg.injection.backend, "injection")
class LockTests(unittest.TestCase): class LockTests(unittest.TestCase):
def test_lock_rejects_second_instance(self): def test_lock_rejects_second_instance(self):

View file

@ -26,6 +26,7 @@ class ConfigTests(unittest.TestCase):
self.assertFalse(cfg.injection.remove_transcription_from_clipboard) self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
self.assertEqual(cfg.vocabulary.replacements, []) self.assertEqual(cfg.vocabulary.replacements, [])
self.assertEqual(cfg.vocabulary.terms, []) self.assertEqual(cfg.vocabulary.terms, [])
self.assertTrue(cfg.domain_inference.enabled)
self.assertTrue(missing.exists()) self.assertTrue(missing.exists())
written = json.loads(missing.read_text(encoding="utf-8")) written = json.loads(missing.read_text(encoding="utf-8"))
@ -47,6 +48,7 @@ class ConfigTests(unittest.TestCase):
], ],
"terms": ["Systemd", "Kubernetes"], "terms": ["Systemd", "Kubernetes"],
}, },
"domain_inference": {"enabled": True},
} }
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json" path = Path(td) / "config.json"
@ -64,6 +66,7 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha") self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta") self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"]) self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
self.assertTrue(cfg.domain_inference.enabled)
def test_super_modifier_hotkey_is_valid(self): def test_super_modifier_hotkey_is_valid(self):
payload = {"daemon": {"hotkey": "Super+m"}} payload = {"daemon": {"hotkey": "Super+m"}}
@ -95,6 +98,28 @@ class ConfigTests(unittest.TestCase):
): ):
load(str(path)) load(str(path))
def test_loads_legacy_keys(self):
payload = {
"hotkey": "Alt+m",
"input": "Mic",
"whisper_model": "tiny",
"whisper_device": "cpu",
"injection_backend": "clipboard",
}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path))
self.assertEqual(cfg.daemon.hotkey, "Alt+m")
self.assertEqual(cfg.recording.input, "Mic")
self.assertEqual(cfg.stt.model, "tiny")
self.assertEqual(cfg.stt.device, "cpu")
self.assertEqual(cfg.injection.backend, "clipboard")
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
self.assertEqual(cfg.vocabulary.replacements, [])
def test_invalid_injection_backend_raises(self): def test_invalid_injection_backend_raises(self):
payload = {"injection": {"backend": "invalid"}} payload = {"injection": {"backend": "invalid"}}
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
@ -113,20 +138,41 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"): with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
load(str(path)) load(str(path))
def test_unknown_top_level_fields_are_ignored(self): def test_removed_ai_section_raises(self):
payload = { payload = {"ai": {"enabled": True}}
"custom_a": {"enabled": True},
"custom_b": {"nested": "value"},
"custom_c": 123,
}
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json" path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8") path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path)) with self.assertRaisesRegex(ValueError, "ai section is no longer supported"):
load(str(path))
self.assertEqual(cfg.daemon.hotkey, "Cmd+m") def test_removed_legacy_ai_enabled_raises(self):
self.assertEqual(cfg.injection.backend, "clipboard") payload = {"ai_enabled": True}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "ai_enabled is no longer supported"):
load(str(path))
def test_removed_logging_section_raises(self):
payload = {"logging": {"log_transcript": True}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "no longer supported"):
load(str(path))
def test_removed_legacy_log_transcript_raises(self):
payload = {"log_transcript": True}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "no longer supported"):
load(str(path))
def test_conflicting_replacements_raise(self): def test_conflicting_replacements_raise(self):
payload = { payload = {
@ -178,15 +224,32 @@ class ConfigTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "wildcard"): with self.assertRaisesRegex(ValueError, "wildcard"):
load(str(path)) load(str(path))
def test_unknown_vocabulary_fields_are_ignored(self): def test_removed_domain_mode_raises(self):
payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}} payload = {"domain_inference": {"mode": "heuristic"}}
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json" path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8") path.write_text(json.dumps(payload), encoding="utf-8")
cfg = load(str(path)) with self.assertRaisesRegex(ValueError, "domain_inference.mode is no longer supported"):
load(str(path))
self.assertEqual(cfg.vocabulary.terms, ["Docker"]) def test_removed_vocabulary_max_rules_raises(self):
payload = {"vocabulary": {"max_rules": 100}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "vocabulary.max_rules is no longer supported"):
load(str(path))
def test_removed_vocabulary_max_terms_raises(self):
payload = {"vocabulary": {"max_terms": 100}}
with tempfile.TemporaryDirectory() as td:
path = Path(td) / "config.json"
path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(ValueError, "vocabulary.max_terms is no longer supported"):
load(str(path))
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1,92 +0,0 @@
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
class EngineTests(unittest.TestCase):
    """Engine failure-policy behavior and PipelineLib argument forwarding."""

    def _engine(self) -> Engine:
        """Engine backed by trivial stub transcribe/LLM callables."""
        return Engine(
            PipelineLib(
                transcribe_fn=lambda *_args, **_kwargs: "text",
                llm_fn=lambda **_kwargs: "result",
            )
        )

    def _failing_binding(self, policy: str) -> PipelineBinding:
        """Binding whose handler always raises RuntimeError('boom')."""

        def explode(_audio, _lib):
            raise RuntimeError("boom")

        return PipelineBinding(
            hotkey="Cmd+m",
            handler=explode,
            options=PipelineOptions(failure_policy=policy),
        )

    def test_best_effort_pipeline_failure_returns_empty_string(self):
        engine = self._engine()
        out = engine.run(self._failing_binding("best_effort"), object())
        self.assertEqual(out, "")

    def test_strict_pipeline_failure_raises(self):
        engine = self._engine()
        with self.assertRaisesRegex(RuntimeError, "boom"):
            engine.run(self._failing_binding("strict"), object())

    def test_pipeline_lib_forwards_arguments(self):
        seen = {}

        def transcribe_fn(audio, *, hints=None, whisper_opts=None):
            seen.update(audio=audio, hints=hints, whisper_opts=whisper_opts)
            return "hello"

        def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
            seen.update(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                llm_opts=llm_opts,
            )
            return "world"

        lib = PipelineLib(transcribe_fn=transcribe_fn, llm_fn=llm_fn)
        audio = object()

        self.assertEqual(
            lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
            "hello",
        )
        self.assertEqual(
            lib.llm(
                system_prompt="sys",
                user_prompt="user",
                llm_opts={"temperature": 0.2},
            ),
            "world",
        )

        # Every argument must reach the underlying callables untouched.
        self.assertIs(seen["audio"], audio)
        self.assertEqual(seen["hints"], ["Docker"])
        self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
        self.assertEqual(seen["system_prompt"], "sys")
        self.assertEqual(seen["user_prompt"], "user")
        self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
# Allow running this test module directly, e.g. `python test_engine.py`.
if __name__ == "__main__":
    unittest.main()

View file

@ -1,109 +0,0 @@
import tempfile
import textwrap
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
from pipelines_runtime import fingerprint, load_bindings
class PipelinesRuntimeTests(unittest.TestCase):
    """Loading hotkey->pipeline bindings and fingerprinting a user pipelines.py."""

    def test_missing_file_uses_default_binding(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "pipelines.py"
            calls = {"count": 0}

            def factory():
                calls["count"] += 1
                return lambda _audio, _lib: "ok"

            bindings = load_bindings(
                path=path,
                default_hotkey="Cmd+m",
                default_handler_factory=factory,
            )

            # No file on disk -> exactly one binding built from the factory.
            self.assertEqual(list(bindings), ["Cmd+m"])
            self.assertEqual(calls["count"], 1)

    def test_loads_module_bindings_and_options(self):
        module_source = textwrap.dedent(
            """
            def p1(audio, lib):
                return "one"

            def p2(audio, lib):
                return "two"

            HOTKEY_PIPELINES = {
                "Cmd+m": p1,
                "Ctrl+space": p2,
            }

            PIPELINE_OPTIONS = {
                "Ctrl+space": {"failure_policy": "strict"},
            }
            """
        )
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "pipelines.py"
            path.write_text(module_source, encoding="utf-8")

            bindings = load_bindings(
                path=path,
                default_hotkey="Cmd+m",
                default_handler_factory=lambda: (lambda _audio, _lib: "default"),
            )

            self.assertEqual(sorted(bindings), ["Cmd+m", "Ctrl+space"])
            # Hotkeys without explicit options keep the default policy.
            self.assertEqual(bindings["Cmd+m"].options.failure_policy, "best_effort")
            self.assertEqual(bindings["Ctrl+space"].options.failure_policy, "strict")

    def test_invalid_options_fail(self):
        module_source = textwrap.dedent(
            """
            def p(audio, lib):
                return "ok"

            HOTKEY_PIPELINES = {"Cmd+m": p}
            PIPELINE_OPTIONS = {"Cmd+m": {"failure_policy": "invalid"}}
            """
        )
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "pipelines.py"
            path.write_text(module_source, encoding="utf-8")
            with self.assertRaisesRegex(ValueError, "failure_policy"):
                load_bindings(
                    path=path,
                    default_hotkey="Cmd+m",
                    default_handler_factory=lambda: (lambda _audio, _lib: "default"),
                )

    def test_fingerprint_changes_with_content(self):
        with tempfile.TemporaryDirectory() as td:
            path = Path(td) / "pipelines.py"
            # Missing file -> no fingerprint at all.
            self.assertIsNone(fingerprint(path))

            path.write_text("HOTKEY_PIPELINES = {}", encoding="utf-8")
            first = fingerprint(path)
            path.write_text(
                "HOTKEY_PIPELINES = {'Cmd+m': lambda a, l: ''}", encoding="utf-8"
            )
            second = fingerprint(path)

            self.assertIsNotNone(first)
            self.assertIsNotNone(second)
            self.assertNotEqual(first, second)
# Allow running this test module directly, e.g. `python test_pipelines_runtime.py`.
if __name__ == "__main__":
    unittest.main()

View file

@ -7,17 +7,18 @@ SRC = ROOT / "src"
if str(SRC) not in sys.path: if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC)) sys.path.insert(0, str(SRC))
from config import VocabularyConfig, VocabularyReplacement from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
from vocabulary import VocabularyEngine from vocabulary import DOMAIN_GENERAL, VocabularyEngine
class VocabularyEngineTests(unittest.TestCase): class VocabularyEngineTests(unittest.TestCase):
def _engine(self, replacements=None, terms=None): def _engine(self, replacements=None, terms=None, domain_enabled=True):
vocab = VocabularyConfig( vocab = VocabularyConfig(
replacements=replacements or [], replacements=replacements or [],
terms=terms or [], terms=terms or [],
) )
return VocabularyEngine(vocab) domain = DomainInferenceConfig(enabled=domain_enabled)
return VocabularyEngine(vocab, domain)
def test_boundary_aware_replacement(self): def test_boundary_aware_replacement(self):
engine = self._engine( engine = self._engine(
@ -49,5 +50,27 @@ class VocabularyEngineTests(unittest.TestCase):
self.assertLessEqual(len(hotwords), 1024) self.assertLessEqual(len(hotwords), 1024)
self.assertLessEqual(len(prompt), 600) self.assertLessEqual(len(prompt), 600)
def test_domain_inference_general_fallback(self):
engine = self._engine()
result = engine.infer_domain("please call me later")
self.assertEqual(result.name, DOMAIN_GENERAL)
self.assertEqual(result.confidence, 0.0)
def test_domain_inference_for_technical_text(self):
engine = self._engine(terms=["Docker", "Systemd"])
result = engine.infer_domain("restart Docker and systemd service on prod")
self.assertNotEqual(result.name, DOMAIN_GENERAL)
self.assertGreater(result.confidence, 0.0)
def test_domain_inference_can_be_disabled(self):
engine = self._engine(domain_enabled=False)
result = engine.infer_domain("please restart docker")
self.assertEqual(result.name, DOMAIN_GENERAL)
self.assertEqual(result.confidence, 0.0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()