diff --git a/README.md b/README.md index c7510b1..55cca01 100644 --- a/README.md +++ b/README.md @@ -80,8 +80,7 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi { "from": "docker", "to": "Docker" } ], "terms": ["Systemd", "Kubernetes"] - }, - "domain_inference": { "enabled": true } + } } ``` @@ -92,6 +91,9 @@ Hotkey notes: - Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`). - `Super` and `Cmd` are equivalent aliases for the same modifier. +- Invalid hotkey syntax in config prevents startup/reload. +- When `~/.config/aman/pipelines.py` exists, hotkeys come from `HOTKEY_PIPELINES`. +- `daemon.hotkey` is used as the fallback/default hotkey only when no pipelines file is present. AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model downloaded to `~/.cache/aman/models/` during daemon initialization. @@ -107,11 +109,6 @@ Vocabulary correction: - Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules. - Rules are deduplicated case-insensitively; conflicting replacements are rejected. -Domain inference: - -- Domain context is advisory only and is used to improve cleanup prompts. -- When confidence is low, it falls back to `general` context. - STT hinting: - Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those @@ -134,6 +131,12 @@ systemctl --user enable --now aman - Press it again to stop and run STT. - Press `Esc` while recording to cancel without processing. - Transcript contents are logged only when `-v/--verbose` is used. +- Config changes are hot-reloaded automatically (polled every 1 second). +- `~/.config/aman/pipelines.py` changes are hot-reloaded automatically (polled every 1 second). +- Send `SIGHUP` to force an immediate reload of config and pipelines: + `systemctl --user kill -s HUP aman` (or send `HUP` to the process directly). 
+- Reloads are applied when the daemon is idle; invalid updates are rejected and the last valid config stays active. +- Reload success/failure is logged, and desktop notifications are shown when available. Wayland note: @@ -149,6 +152,77 @@ AI processing: - Local llama.cpp model only (no remote provider configuration). +## Pipelines API + +`aman` is split into: + +- shell daemon: hotkeys, recording/cancel, and desktop injection +- pipeline engine: `lib.transcribe(...)` and `lib.llm(...)` +- pipeline implementation: Python callables mapped per hotkey + +Pipeline file path: + +- `~/.config/aman/pipelines.py` +- You can start from [`pipelines.example.py`](./pipelines.example.py). +- If `pipelines.py` is missing, `aman` uses a built-in reference pipeline bound to `daemon.hotkey`. +- If `pipelines.py` exists but is invalid, startup fails fast. +- Pipelines are hot-reloaded automatically when the module file changes. +- Send `SIGHUP` to force an immediate reload of both config and pipelines. + +Expected module exports: + +```python +HOTKEY_PIPELINES = { + "Super+m": my_pipeline, + "Super+Shift+m": caps_pipeline, +} + +PIPELINE_OPTIONS = { + "Super+Shift+m": {"failure_policy": "strict"}, # optional +} +``` + +Pipeline callable signature: + +```python +def my_pipeline(audio, lib) -> str: + text = lib.transcribe(audio) + context = lib.llm( + system_prompt="context system prompt", + user_prompt=f"Transcript: {text}", + ) + out = lib.llm( + system_prompt="amanuensis prompt", + user_prompt=f"context={context}\ntext={text}", + ) + return out +``` + +`lib` API: + +- `lib.transcribe(audio, hints=None, whisper_opts=None) -> str` +- `lib.llm(system_prompt=..., user_prompt=..., llm_opts=None) -> str` + +Failure policy options: + +- `best_effort` (default): pipeline errors return empty output +- `strict`: pipeline errors abort the current run + +Validation: + +- `HOTKEY_PIPELINES` must be a non-empty dictionary. +- Every hotkey key must be a non-empty string. 
"""Example ``pipelines.py`` module for aman.

Exports ``HOTKEY_PIPELINES`` (hotkey -> pipeline callable) and the optional
``PIPELINE_OPTIONS`` (hotkey -> options dictionary). Every pipeline receives
the recorded ``audio`` and a ``lib`` object exposing ``transcribe`` and ``llm``.
"""
import json


CONTEXT_PROMPT = (
    "Return a concise plain-text context hint (max 12 words) "
    "for the transcript domain and style."
)

AMANUENSIS_PROMPT = (
    "Rewrite the transcript into clean prose without changing intent. "
    "Return only the final text."
)


def _ask(lib, system_prompt, payload) -> str:
    """Send *payload* as a JSON user prompt at temperature 0.0 and return the raw reply."""
    return lib.llm(
        system_prompt=system_prompt,
        user_prompt=json.dumps(payload, ensure_ascii=False),
        llm_opts={"temperature": 0.0},
    )


def default_pipeline(audio, lib) -> str:
    """Transcribe *audio*, infer a short context hint, then rewrite the transcript.

    Returns an empty string, without calling the LLM, when the transcript is
    empty. Otherwise two LLM calls are made: the first produces the context
    hint, the second receives the transcript together with that hint.
    """
    transcript = lib.transcribe(audio)
    if transcript:
        hint = _ask(lib, CONTEXT_PROMPT, {"transcript": transcript}).strip()
        rewritten = _ask(
            lib,
            AMANUENSIS_PROMPT,
            {"transcript": transcript, "context": hint},
        )
        return rewritten.strip()
    return ""


def caps_pipeline(audio, lib) -> str:
    """Return the transcript of *audio* converted to upper case (no LLM involved)."""
    spoken = lib.transcribe(audio)
    return spoken.upper()


# Hotkey -> pipeline callable.
HOTKEY_PIPELINES = {
    "Super+m": default_pipeline,
    "Super+Shift+m": caps_pipeline,
}


# Optional per-hotkey options.
PIPELINE_OPTIONS = {
    "Super+m": {"failure_policy": "best_effort"},
    "Super+Shift+m": {"failure_policy": "strict"},
}
system_prompt=SYSTEM_PROMPT, + user_prompt=json.dumps(request_payload, ensure_ascii=False), + llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}}, + ) + return _extract_cleaned_text_from_raw(content) + + def chat( + self, + *, + system_prompt: str, + user_prompt: str, + llm_opts: dict[str, Any] | None = None, + ) -> str: + opts = dict(llm_opts or {}) + temperature = float(opts.pop("temperature", 0.0)) + response_format = opts.pop("response_format", None) + if opts: + unknown = ", ".join(sorted(opts.keys())) + raise ValueError(f"unsupported llm options: {unknown}") + kwargs: dict[str, Any] = { "messages": [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)}, + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, ], - "temperature": 0.0, + "temperature": temperature, } - if _supports_response_format(self.client.create_chat_completion): - kwargs["response_format"] = {"type": "json_object"} + if response_format is not None and _supports_response_format( + self.client.create_chat_completion + ): + kwargs["response_format"] = response_format response = self.client.create_chat_completion(**kwargs) - return _extract_cleaned_text(response) + return _extract_chat_text(response) def ensure_model(): @@ -148,6 +171,10 @@ def _extract_chat_text(payload: Any) -> str: def _extract_cleaned_text(payload: Any) -> str: raw = _extract_chat_text(payload) + return _extract_cleaned_text_from_raw(raw) + + +def _extract_cleaned_text_from_raw(raw: str) -> str: try: parsed = json.loads(raw) except json.JSONDecodeError as exc: diff --git a/src/aman.py b/src/aman.py index 0ffbc20..cb79903 100755 --- a/src/aman.py +++ b/src/aman.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import errno +import hashlib import inspect import json import logging @@ -12,12 +13,21 @@ import sys import threading import time from pathlib import Path -from typing 
def _replacement_view(cfg: Config) -> list[tuple[str, str]]:
    """Return the vocabulary replacements of *cfg* as ``(source, target)`` pairs."""
    pairs: list[tuple[str, str]] = []
    for rule in cfg.vocabulary.replacements:
        pairs.append((rule.source, rule.target))
    return pairs


def _config_fields_changed(current: Config, candidate: Config) -> list[str]:
    """List the dotted names of the settings that differ between two configs.

    The names are reported in a fixed order: daemon, recording, stt,
    injection, then vocabulary. Replacements are compared through
    ``_replacement_view`` so only their source/target pairs matter.
    """
    probes = (
        ("daemon.hotkey", lambda cfg: cfg.daemon.hotkey),
        ("recording.input", lambda cfg: cfg.recording.input),
        ("stt.model", lambda cfg: cfg.stt.model),
        ("stt.device", lambda cfg: cfg.stt.device),
        ("injection.backend", lambda cfg: cfg.injection.backend),
        (
            "injection.remove_transcription_from_clipboard",
            lambda cfg: cfg.injection.remove_transcription_from_clipboard,
        ),
        ("vocabulary.replacements", _replacement_view),
        ("vocabulary.terms", lambda cfg: cfg.vocabulary.terms),
    )
    return [label for label, read in probes if read(current) != read(candidate)]
def _extract_cleaned_text_from_llm_raw(raw: str) -> str:
    """Extract the cleaned text from a raw LLM reply.

    Two reply shapes are accepted: a JSON string, returned as-is, or a JSON
    object whose ``cleaned_text`` value is a string.

    Args:
        raw: The reply text produced by the LLM.

    Returns:
        The cleaned text.

    Raises:
        RuntimeError: If *raw* is not valid JSON or does not have one of the
            accepted shapes.
    """
    try:
        parsed = json.loads(raw)
    except (TypeError, json.JSONDecodeError) as exc:
        # A reply that is not JSON is the same kind of failure as a JSON reply
        # without cleaned_text, so report both with the same exception type.
        raise RuntimeError("unexpected ai output format: expected JSON") from exc

    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, dict):
        cleaned_text = parsed.get("cleaned_text")
        if isinstance(cleaned_text, str):
            return cleaned_text
    raise RuntimeError("unexpected ai output format: missing cleaned_text")
def apply_pipeline_bindings(
    self,
    bindings: dict[str, PipelineBinding],
    callbacks: dict[str, Callable[[], None]],
) -> tuple[str, str]:
    """Register *callbacks* with the desktop and make *bindings* the active set.

    Args:
        bindings: Hotkey -> pipeline binding mapping to activate.
        callbacks: Hotkey -> zero-argument callback passed to
            ``self.desktop.set_hotkeys``.

    Returns:
        A ``(status, detail)`` pair. ``status`` is ``"applied"`` (detail is
        empty), ``"deferred"`` (shutdown requested, another reload running,
        or the daemon is not idle) or ``"error"`` (invalid input, or
        ``set_hotkeys`` raised; detail carries the message).
    """
    # Validate before touching any state. set_pipeline_bindings also rejects
    # an empty mapping, and a binding without a callback could never fire.
    if not bindings:
        return "error", "at least one pipeline binding is required"
    unbound = sorted(str(hotkey) for hotkey in bindings if hotkey not in callbacks)
    if unbound:
        return "error", f"missing hotkey callbacks: {', '.join(unbound)}"

    with self.lock:
        if self._shutdown_requested.is_set():
            return "deferred", "shutdown in progress"
        if self._config_apply_in_progress:
            return "deferred", "reload in progress"
        if self.state != State.IDLE:
            return "deferred", f"daemon is busy ({self.state})"
        # From here on handle_hotkey ignores triggers until the flag is reset.
        self._config_apply_in_progress = True
    try:
        self.desktop.set_hotkeys(callbacks)
        with self.lock:
            # Re-check: the lock was released while the desktop was updated.
            if self._shutdown_requested.is_set():
                return "deferred", "shutdown in progress"
            if self.state != State.IDLE:
                return "deferred", f"daemon is busy ({self.state})"
            self._pipeline_bindings = dict(bindings)
            # Same cleanup as set_pipeline_bindings: forget a remembered
            # hotkey that no longer has a binding.
            if self._active_hotkey and self._active_hotkey not in self._pipeline_bindings:
                self._active_hotkey = None
        return "applied", ""
    except Exception as exc:
        return "error", str(exc)
    finally:
        with self.lock:
            self._config_apply_in_progress = False
def _start_stop_worker(
    self,
    stream: Any,
    record: Any,
    trigger: str,
    process_audio: bool,
    active_hotkey: str | None = None,
):
    """Start a daemon thread that runs ``_stop_and_process`` for one recording.

    Args:
        stream: Recording stream forwarded to ``_stop_and_process``.
        record: Recording handle forwarded to ``_stop_and_process``.
        trigger: Label describing why recording stopped.
        process_audio: Forwarded unchanged to ``_stop_and_process``.
        active_hotkey: Hotkey whose pipeline should run. When omitted, the
            value left in ``self._pending_worker_hotkey`` is used, then
            ``self.cfg.daemon.hotkey``.
    """
    # Consume the hand-off field so a later call that does not set it cannot
    # pick up the hotkey of a previous recording.
    pending_hotkey = self._pending_worker_hotkey
    self._pending_worker_hotkey = None
    hotkey = active_hotkey or pending_hotkey or self.cfg.daemon.hotkey
    threading.Thread(
        target=self._stop_and_process,
        args=(stream, record, trigger, process_audio, hotkey),
        daemon=True,
    ).start()
def _run_pipeline(self, audio: Any, hotkey: str) -> str:
    """Run the pipeline bound to *hotkey* on *audio* and return its output.

    When *hotkey* has no binding, the first registered binding is used
    instead. When no binding is registered at all, an empty string is
    returned. Both situations are logged so they are not silent.

    Args:
        audio: Recorded audio handed to the pipeline.
        hotkey: Hotkey that started the recording.

    Returns:
        The result of ``self.engine.run`` for the selected binding, or ``""``.
    """
    fallback_hotkey = None
    with self.lock:
        binding = self._pipeline_bindings.get(hotkey)
        if binding is None and self._pipeline_bindings:
            fallback_hotkey, binding = next(iter(self._pipeline_bindings.items()))

    # Log outside the lock; only the lookup needs protection.
    if fallback_hotkey is not None:
        logging.warning(
            "no pipeline bound to %s; falling back to %s", hotkey, fallback_hotkey
        )
    if binding is None:
        logging.warning("no pipeline bindings registered; %s ignored", hotkey)
        return ""
    return self.engine.run(binding, audio)
def build_reference_pipeline(self) -> Callable[[Any, PipelineLib], str]:
    """Build the built-in pipeline callable.

    The returned callable transcribes the audio, asks the LLM for a short
    context hint, asks the LLM again for a JSON rewrite, and finally applies
    the vocabulary's deterministic replacements. The vocabulary is read from
    ``self`` each time the pipeline runs.
    """

    def _reference_pipeline(audio: Any, lib: PipelineLib) -> str:
        transcript = (lib.transcribe(audio) or "").strip()
        if not transcript:
            # Nothing was said: skip both LLM passes.
            return ""

        dictionary_context = self.vocabulary.build_ai_dictionary_context()

        # Pass 1: infer a context hint from the transcript.
        hint_request = json.dumps(
            {"transcript": transcript, "dictionary": dictionary_context},
            ensure_ascii=False,
        )
        hint_reply = lib.llm(
            system_prompt=CONTEXT_SYSTEM_PROMPT,
            user_prompt=hint_request,
            llm_opts={"temperature": 0.0},
        )
        context = hint_reply.strip()

        # Pass 2: rewrite the transcript, requesting a JSON object reply.
        rewrite_request: dict[str, Any] = {
            "language": STT_LANGUAGE,
            "transcript": transcript,
            "context": context,
        }
        if dictionary_context:
            rewrite_request["dictionary"] = dictionary_context
        ai_raw = lib.llm(
            system_prompt=SYSTEM_PROMPT,
            user_prompt=json.dumps(rewrite_request, ensure_ascii=False),
            llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}},
        )

        rewritten = _extract_cleaned_text_from_llm_raw(ai_raw)
        replaced = self.vocabulary.apply_deterministic_replacements(rewritten)
        return replaced.strip()

    return _reference_pipeline
def _llm_with_options(
    self,
    *,
    system_prompt: str,
    user_prompt: str,
    llm_opts: dict[str, Any] | None = None,
) -> str:
    """Forward one chat request to the AI processor and return its reply.

    Switches the daemon to ``State.PROCESSING`` and logs the start of LLM
    work, unless the daemon is already in that state.
    """
    already_processing = self.get_state() == State.PROCESSING
    if not already_processing:
        self.set_state(State.PROCESSING)
        logging.info("llm processing started")
    return self._get_ai_processor().chat(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        llm_opts=llm_opts,
    )
class ConfigReloader:
    """Poll the config file and apply reloads to the daemon while it is idle.

    A change is detected by comparing a SHA-256 digest of the file contents
    on every tick. Reloads can also be requested explicitly through
    ``request_reload``. A pending request is taken atomically before it is
    processed, so a request that arrives while a reload is being loaded or
    applied is kept for the next tick instead of being erased.
    """

    def __init__(
        self,
        *,
        daemon: Daemon,
        config_path: Path,
        notifier: DesktopNotifier | None = None,
        poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC,
    ):
        """Store collaborators and record the file's current fingerprint.

        Args:
            daemon: Daemon that receives the reloaded config.
            config_path: Path of the config file to watch.
            notifier: Optional desktop notifier for reload results.
            poll_interval_sec: Seconds between polls of the background thread.
        """
        self.daemon = daemon
        self.config_path = config_path
        self.notifier = notifier
        self.poll_interval_sec = poll_interval_sec
        self._stop_event = threading.Event()
        # Guards the two _pending_* fields below.
        self._request_lock = threading.Lock()
        self._pending_reload_reason = ""
        self._pending_force = False
        self._last_seen_fingerprint = self._fingerprint()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the polling thread; does nothing if it was already started."""
        if self._thread is not None:
            return
        self._thread = threading.Thread(target=self._run, daemon=True, name="config-reloader")
        self._thread.start()

    def stop(self, timeout: float = 2.0) -> None:
        """Signal the polling thread to stop and wait up to *timeout* seconds."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=timeout)
            self._thread = None

    def request_reload(self, reason: str, *, force: bool = False) -> None:
        """Queue a reload; reasons of queued requests are joined with ``"; "``.

        Args:
            reason: Text describing why the reload was requested.
            force: When true, a reload with no effective changes is still
                logged and notified.
        """
        with self._request_lock:
            if self._pending_reload_reason:
                self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}"
            else:
                self._pending_reload_reason = reason
            self._pending_force = self._pending_force or force
        logging.debug("config reload requested: %s (force=%s)", reason, force)

    def tick(self) -> None:
        """Run one poll: detect a file change, then apply any pending reload."""
        self._detect_file_change()
        self._apply_if_pending()

    def _run(self) -> None:
        """Call ``tick`` every ``poll_interval_sec`` until ``stop`` is called."""
        while not self._stop_event.wait(self.poll_interval_sec):
            self.tick()

    def _detect_file_change(self) -> None:
        """Queue a non-forced reload when the file fingerprint has changed."""
        fingerprint = self._fingerprint()
        if fingerprint == self._last_seen_fingerprint:
            return
        self._last_seen_fingerprint = fingerprint
        self.request_reload("file change detected", force=False)

    def _apply_if_pending(self) -> None:
        """Load and apply the config if a reload is queued and the daemon is idle."""
        with self._request_lock:
            has_request = bool(self._pending_reload_reason)
        if not has_request:
            return

        if self.daemon.get_state() != State.IDLE:
            logging.debug("config reload deferred; daemon state=%s", self.daemon.get_state())
            return

        # Take the request before the slow load/apply steps so that a request
        # arriving meanwhile stays queued instead of being cleared afterwards.
        reason, force_reload = self._take_pending()
        if not reason:
            return

        try:
            candidate = load(str(self.config_path))
        except Exception as exc:
            msg = f"config reload failed ({reason}): {exc}"
            logging.error(msg)
            self._notify("Config Reload Failed", str(exc), error=True)
            return

        status, changed, error = self.daemon.try_apply_config(candidate)
        if status == "deferred":
            # Not applied: put the request back so the next tick retries it.
            self._requeue(reason, force_reload)
            logging.debug("config reload deferred during apply: %s", error)
            return
        if status == "error":
            msg = f"config reload failed ({reason}): {error}"
            logging.error(msg)
            self._notify("Config Reload Failed", error, error=True)
            return

        if changed:
            msg = ", ".join(changed)
            logging.info("config reloaded (%s): %s", reason, msg)
            self._notify("Config Reloaded", f"Applied: {msg}", error=False)
            return
        if force_reload:
            logging.info("config reloaded (%s): no effective changes", reason)
            self._notify("Config Reloaded", "No effective changes", error=False)

    def _take_pending(self) -> tuple[str, bool]:
        """Atomically return and reset the queued ``(reason, force)`` request."""
        with self._request_lock:
            taken = (self._pending_reload_reason, self._pending_force)
            self._pending_reload_reason = ""
            self._pending_force = False
        return taken

    def _requeue(self, reason: str, force: bool) -> None:
        """Put a taken request back in front of any request queued meanwhile."""
        with self._request_lock:
            newer = self._pending_reload_reason
            self._pending_reload_reason = f"{reason}; {newer}" if newer else reason
            self._pending_force = self._pending_force or force

    def _clear_pending(self) -> None:
        """Drop any queued reload request."""
        with self._request_lock:
            self._pending_reload_reason = ""
            self._pending_force = False

    def _fingerprint(self) -> str | None:
        """Return the SHA-256 hex digest of the config file.

        Returns ``None`` when the file is missing or cannot be read; a read
        error other than a missing file is logged as a warning.
        """
        try:
            data = self.config_path.read_bytes()
            return hashlib.sha256(data).hexdigest()
        except FileNotFoundError:
            return None
        except OSError as exc:
            logging.warning("config fingerprint failed (%s): %s", self.config_path, exc)
            return None

    def _notify(self, title: str, body: str, *, error: bool) -> None:
        """Send a desktop notification when a notifier was provided."""
        if self.notifier is None:
            return
        sent = self.notifier.send(title, body, error=error)
        if not sent:
            logging.debug("desktop notifications unavailable: %s", title)
self._pending_reload_reason = "" + self._pending_force = False + self._last_seen_fingerprint = pipelines_fingerprint(self.pipelines_path) + self._thread: threading.Thread | None = None + + def start(self) -> None: + if self._thread is not None: + return + self._thread = threading.Thread(target=self._run, daemon=True, name="pipeline-reloader") + self._thread.start() + + def stop(self, timeout: float = 2.0) -> None: + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=timeout) + self._thread = None + + def request_reload(self, reason: str, *, force: bool = False) -> None: + with self._request_lock: + if self._pending_reload_reason: + self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}" + else: + self._pending_reload_reason = reason + self._pending_force = self._pending_force or force + logging.debug("pipeline reload requested: %s (force=%s)", reason, force) + + def tick(self) -> None: + self._detect_file_change() + self._apply_if_pending() + + def _run(self) -> None: + while not self._stop_event.wait(self.poll_interval_sec): + self.tick() + + def _detect_file_change(self) -> None: + fingerprint = pipelines_fingerprint(self.pipelines_path) + if fingerprint == self._last_seen_fingerprint: + return + self._last_seen_fingerprint = fingerprint + self.request_reload("file change detected", force=False) + + def _apply_if_pending(self) -> None: + with self._request_lock: + reason = self._pending_reload_reason + force_reload = self._pending_force + if not reason: + return + if self.daemon.get_state() != State.IDLE: + logging.debug("pipeline reload deferred; daemon state=%s", self.daemon.get_state()) + return + + try: + bindings = load_bindings( + path=self.pipelines_path, + default_hotkey=self.daemon.cfg.daemon.hotkey, + default_handler_factory=self.daemon.build_reference_pipeline, + ) + callbacks = self.daemon.build_hotkey_callbacks( + list(bindings.keys()), + dry_run=self.dry_run, + ) + except Exception as exc: + 
self._clear_pending() + logging.error("pipeline reload failed (%s): %s", reason, exc) + self._notify("Pipelines Reload Failed", str(exc), error=True) + return + + status, error = self.daemon.apply_pipeline_bindings(bindings, callbacks) + if status == "deferred": + logging.debug("pipeline reload deferred during apply: %s", error) + return + self._clear_pending() + if status == "error": + logging.error("pipeline reload failed (%s): %s", reason, error) + self._notify("Pipelines Reload Failed", error, error=True) + return + + hotkeys = ", ".join(sorted(bindings.keys())) + logging.info("pipelines reloaded (%s): %s", reason, hotkeys) + self._notify("Pipelines Reloaded", f"Hotkeys: {hotkeys}", error=False) + if force_reload and not bindings: + logging.info("pipelines reloaded (%s): no hotkeys", reason) + + def _clear_pending(self) -> None: + with self._request_lock: + self._pending_reload_reason = "" + self._pending_force = False + + def _notify(self, title: str, body: str, *, error: bool) -> None: + if self.notifier is None: + return + sent = self.notifier.send(title, body, error=error) + if not sent: + logging.debug("desktop notifications unavailable: %s", title) def _read_lock_pid(lock_file) -> str: @@ -366,19 +869,49 @@ def main(): level=logging.DEBUG if args.verbose else logging.INFO, format="aman: %(asctime)s %(levelname)s %(message)s", ) - cfg = load(args.config) + config_path = Path(args.config) if args.config else DEFAULT_CONFIG_PATH + cfg = load(str(config_path)) _LOCK_HANDLE = _lock_single_instance() logging.info("hotkey: %s", cfg.daemon.hotkey) logging.info( "config (%s):\n%s", - args.config or str(Path.home() / ".config" / "aman" / "config.json"), + str(config_path), json.dumps(redacted_dict(cfg), indent=2), ) + config_reloader = None + pipeline_reloader = None try: desktop = get_desktop_adapter() daemon = Daemon(cfg, desktop, verbose=args.verbose) + notifier = DesktopNotifier(app_name="aman") + config_reloader = ConfigReloader( + daemon=daemon, + 
config_path=config_path, + notifier=notifier, + poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC, + ) + pipelines_path = DEFAULT_PIPELINES_PATH + initial_bindings = load_bindings( + path=pipelines_path, + default_hotkey=cfg.daemon.hotkey, + default_handler_factory=daemon.build_reference_pipeline, + ) + initial_callbacks = daemon.build_hotkey_callbacks( + list(initial_bindings.keys()), + dry_run=args.dry_run, + ) + status, error = daemon.apply_pipeline_bindings(initial_bindings, initial_callbacks) + if status != "applied": + raise RuntimeError(f"pipeline setup failed: {error}") + pipeline_reloader = PipelineReloader( + daemon=daemon, + pipelines_path=pipelines_path, + notifier=notifier, + dry_run=args.dry_run, + poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC, + ) except Exception as exc: logging.error("startup failed: %s", exc) raise SystemExit(1) @@ -390,6 +923,10 @@ def main(): return shutdown_once.set() logging.info("%s, shutting down", reason) + if config_reloader is not None: + config_reloader.stop(timeout=2.0) + if pipeline_reloader is not None: + pipeline_reloader.stop(timeout=2.0) if not daemon.shutdown(timeout=5.0): logging.warning("timed out waiting for idle state during shutdown") desktop.request_quit() @@ -397,15 +934,30 @@ def main(): def handle_signal(_sig, _frame): threading.Thread(target=shutdown, args=("signal received",), daemon=True).start() + def handle_reload_signal(_sig, _frame): + if shutdown_once.is_set(): + return + if config_reloader is not None: + threading.Thread( + target=lambda: config_reloader.request_reload("signal SIGHUP", force=True), + daemon=True, + ).start() + if pipeline_reloader is not None: + threading.Thread( + target=lambda: pipeline_reloader.request_reload("signal SIGHUP", force=True), + daemon=True, + ).start() + signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) + signal.signal(signal.SIGHUP, handle_reload_signal) try: - desktop.start_hotkey_listener( - cfg.daemon.hotkey, - 
lambda: logging.info("hotkey pressed (dry-run)") if args.dry_run else daemon.toggle(), - ) desktop.start_cancel_listener(lambda: daemon.cancel_recording()) + if config_reloader is not None: + config_reloader.start() + if pipeline_reloader is not None: + pipeline_reloader.start() except Exception as exc: logging.error("hotkey setup failed: %s", exc) raise SystemExit(1) @@ -413,6 +965,10 @@ def main(): try: desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested")) finally: + if config_reloader is not None: + config_reloader.stop(timeout=2.0) + if pipeline_reloader is not None: + pipeline_reloader.stop(timeout=2.0) daemon.shutdown(timeout=1.0) diff --git a/src/config.py b/src/config.py index 6236966..2557c4f 100644 --- a/src/config.py +++ b/src/config.py @@ -51,11 +51,6 @@ class VocabularyConfig: terms: list[str] = field(default_factory=list) -@dataclass -class DomainInferenceConfig: - enabled: bool = True - - @dataclass class Config: daemon: DaemonConfig = field(default_factory=DaemonConfig) @@ -63,7 +58,6 @@ class Config: stt: SttConfig = field(default_factory=SttConfig) injection: InjectionConfig = field(default_factory=InjectionConfig) vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig) - domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig) def load(path: str | None) -> Config: @@ -124,20 +118,8 @@ def validate(cfg: Config) -> None: cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements) cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms) - if not isinstance(cfg.domain_inference.enabled, bool): - raise ValueError("domain_inference.enabled must be boolean") - def _from_dict(data: dict[str, Any], cfg: Config) -> Config: - if "logging" in data: - raise ValueError("logging section is no longer supported; use -v/--verbose") - if "log_transcript" in data: - raise ValueError("log_transcript is no longer supported; use -v/--verbose") - if "ai" in data: - raise 
ValueError("ai section is no longer supported") - if "ai_enabled" in data: - raise ValueError("ai_enabled is no longer supported") - has_sections = any( key in data for key in ( @@ -146,7 +128,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: "stt", "injection", "vocabulary", - "domain_inference", ) ) if has_sections: @@ -155,7 +136,6 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: stt = _ensure_dict(data.get("stt"), "stt") injection = _ensure_dict(data.get("injection"), "injection") vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary") - domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference") if "hotkey" in daemon: cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey") @@ -176,28 +156,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"]) if "terms" in vocabulary: cfg.vocabulary.terms = _as_terms(vocabulary["terms"]) - if "max_rules" in vocabulary: - raise ValueError("vocabulary.max_rules is no longer supported") - if "max_terms" in vocabulary: - raise ValueError("vocabulary.max_terms is no longer supported") - if "enabled" in domain_inference: - cfg.domain_inference.enabled = _as_bool( - domain_inference["enabled"], "domain_inference.enabled" - ) - if "mode" in domain_inference: - raise ValueError("domain_inference.mode is no longer supported") return cfg - - if "hotkey" in data: - cfg.daemon.hotkey = _as_nonempty_str(data["hotkey"], "hotkey") - if "input" in data: - cfg.recording.input = _as_recording_input(data["input"]) - if "whisper_model" in data: - cfg.stt.model = _as_nonempty_str(data["whisper_model"], "whisper_model") - if "whisper_device" in data: - cfg.stt.device = _as_nonempty_str(data["whisper_device"], "whisper_device") - if "injection_backend" in data: - cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend") return cfg diff --git a/src/constants.py 
b/src/constants.py index df06122..7dd942f 100644 --- a/src/constants.py +++ b/src/constants.py @@ -2,6 +2,7 @@ from pathlib import Path DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json" +CONFIG_RELOAD_POLL_INTERVAL_SEC = 1.0 RECORD_TIMEOUT_SEC = 300 STT_LANGUAGE = "en" TRAY_UPDATE_MS = 250 diff --git a/src/desktop.py b/src/desktop.py index 23ac5f0..1ad650d 100644 --- a/src/desktop.py +++ b/src/desktop.py @@ -5,7 +5,7 @@ from typing import Callable, Protocol class DesktopAdapter(Protocol): - def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None: + def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: raise NotImplementedError def start_cancel_listener(self, callback: Callable[[], None]) -> None: diff --git a/src/desktop_wayland.py b/src/desktop_wayland.py index 1da88a8..871a695 100644 --- a/src/desktop_wayland.py +++ b/src/desktop_wayland.py @@ -4,7 +4,7 @@ from typing import Callable class WaylandAdapter: - def start_hotkey_listener(self, _hotkey: str, _callback: Callable[[], None]) -> None: + def set_hotkeys(self, _bindings: dict[str, Callable[[], None]]) -> None: raise SystemExit("Wayland hotkeys are not supported yet.") def start_cancel_listener(self, _callback: Callable[[], None]) -> None: diff --git a/src/desktop_x11.py b/src/desktop_x11.py index d79ca0d..323371b 100644 --- a/src/desktop_x11.py +++ b/src/desktop_x11.py @@ -4,7 +4,7 @@ import logging import threading import time import warnings -from typing import Callable, Iterable +from typing import Any, Callable, Iterable import gi from Xlib import X, XK, display @@ -43,6 +43,8 @@ class X11Adapter: self.indicator = None self.status_icon = None self.menu = None + self._hotkey_listener_lock = threading.Lock() + self._hotkey_listeners: dict[str, dict[str, Any]] = {} if AppIndicator3 is not None: self.indicator = AppIndicator3.Indicator.new( "aman", @@ -65,15 +67,29 @@ class X11Adapter: if self.menu: self.menu.popup(None, None, None, None, 0, 
_time) - def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None: - mods, keysym = self._parse_hotkey(hotkey) - self._validate_hotkey_registration(mods, keysym) - thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True) - thread.start() + def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: + if not isinstance(bindings, dict): + raise ValueError("bindings must be a dictionary") + next_listeners: dict[str, dict[str, Any]] = {} + try: + for hotkey, callback in bindings.items(): + if not callable(callback): + raise ValueError(f"callback for hotkey {hotkey} must be callable") + next_listeners[hotkey] = self._spawn_hotkey_listener(hotkey, callback) + except Exception: + for listener in next_listeners.values(): + self._stop_hotkey_listener(listener) + raise + + with self._hotkey_listener_lock: + previous = self._hotkey_listeners + self._hotkey_listeners = next_listeners + for listener in previous.values(): + self._stop_hotkey_listener(listener) def start_cancel_listener(self, callback: Callable[[], None]) -> None: mods, keysym = self._parse_hotkey("Escape") - thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True) + thread = threading.Thread(target=self._listen, args=(mods, keysym, callback, threading.Event()), daemon=True) thread.start() def inject_text( @@ -127,7 +143,14 @@ class X11Adapter: finally: self.request_quit() - def _listen(self, mods: int, keysym: int, callback: Callable[[], None]) -> None: + def _listen( + self, + mods: int, + keysym: int, + callback: Callable[[], None], + stop_event: threading.Event, + listener_meta: dict[str, Any] | None = None, + ) -> None: disp = None root = None keycode = None @@ -135,14 +158,26 @@ class X11Adapter: disp = display.Display() root = disp.screen().root keycode = self._grab_hotkey(disp, root, mods, keysym) - while True: + if listener_meta is not None: + listener_meta["display"] = disp + listener_meta["root"] = 
root + listener_meta["keycode"] = keycode + listener_meta["ready"].set() + while not stop_event.is_set(): + if disp.pending_events() == 0: + time.sleep(0.05) + continue ev = disp.next_event() if ev.type == X.KeyPress and ev.detail == keycode: state = ev.state & ~(X.LockMask | X.Mod2Mask) if state == mods: callback() except Exception as exc: - logging.error("hotkey listener stopped: %s", exc) + if listener_meta is not None: + listener_meta["error"] = exc + listener_meta["ready"].set() + if not stop_event.is_set(): + logging.error("hotkey listener stopped: %s", exc) finally: if root is not None and keycode is not None and disp is not None: try: @@ -150,6 +185,13 @@ class X11Adapter: disp.sync() except Exception: pass + if disp is not None: + try: + disp.close() + except Exception: + pass + if listener_meta is not None: + listener_meta["ready"].set() def _parse_hotkey(self, hotkey: str): mods = 0 @@ -185,6 +227,51 @@ class X11Adapter: except Exception: pass + def _spawn_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> dict[str, Any]: + mods, keysym = self._parse_hotkey(hotkey) + self._validate_hotkey_registration(mods, keysym) + stop_event = threading.Event() + listener_meta: dict[str, Any] = { + "hotkey": hotkey, + "mods": mods, + "keysym": keysym, + "stop_event": stop_event, + "display": None, + "root": None, + "keycode": None, + "error": None, + "ready": threading.Event(), + } + thread = threading.Thread( + target=self._listen, + args=(mods, keysym, callback, stop_event, listener_meta), + daemon=True, + ) + listener_meta["thread"] = thread + thread.start() + if not listener_meta["ready"].wait(timeout=2.0): + stop_event.set() + raise RuntimeError("hotkey listener setup timed out") + if listener_meta["error"] is not None: + stop_event.set() + raise listener_meta["error"] + if listener_meta["keycode"] is None: + stop_event.set() + raise RuntimeError("hotkey listener failed to initialize") + return listener_meta + + def _stop_hotkey_listener(self, 
listener: dict[str, Any]) -> None: + listener["stop_event"].set() + disp = listener.get("display") + if disp is not None: + try: + disp.close() + except Exception: + pass + thread = listener.get("thread") + if thread is not None: + thread.join(timeout=1.0) + def _grab_hotkey(self, disp, root, mods, keysym): keycode = disp.keysym_to_keycode(keysym) if keycode == 0: diff --git a/src/engine.py b/src/engine.py new file mode 100644 index 0000000..de54c90 --- /dev/null +++ b/src/engine.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Callable + + +ALLOWED_FAILURE_POLICIES = {"best_effort", "strict"} + + +@dataclass(frozen=True) +class PipelineOptions: + failure_policy: str = "best_effort" + + def __post_init__(self): + policy = (self.failure_policy or "").strip().lower() + if policy not in ALLOWED_FAILURE_POLICIES: + allowed = ", ".join(sorted(ALLOWED_FAILURE_POLICIES)) + raise ValueError(f"failure_policy must be one of: {allowed}") + object.__setattr__(self, "failure_policy", policy) + + +@dataclass(frozen=True) +class PipelineBinding: + hotkey: str + handler: Callable[[Any, "PipelineLib"], str] + options: PipelineOptions + + +class PipelineLib: + def __init__( + self, + *, + transcribe_fn: Callable[[Any], str], + llm_fn: Callable[..., str], + ): + self._transcribe_fn = transcribe_fn + self._llm_fn = llm_fn + + def transcribe( + self, + audio: Any, + *, + hints: list[str] | None = None, + whisper_opts: dict[str, Any] | None = None, + ) -> str: + return self._transcribe_fn( + audio, + hints=hints, + whisper_opts=whisper_opts, + ) + + def llm( + self, + *, + system_prompt: str, + user_prompt: str, + llm_opts: dict[str, Any] | None = None, + ) -> str: + return self._llm_fn( + system_prompt=system_prompt, + user_prompt=user_prompt, + llm_opts=llm_opts, + ) + + +class Engine: + def __init__(self, lib: PipelineLib): + self.lib = lib + + def run(self, binding: PipelineBinding, audio: Any) -> 
str: + try: + output = binding.handler(audio, self.lib) + except Exception as exc: + if binding.options.failure_policy == "strict": + raise + logging.error("pipeline failed for hotkey %s: %s", binding.hotkey, exc) + return "" + + if output is None: + return "" + if not isinstance(output, str): + raise RuntimeError( + f"pipeline for hotkey {binding.hotkey} returned non-string output" + ) + return output diff --git a/src/notify.py b/src/notify.py new file mode 100644 index 0000000..27469a2 --- /dev/null +++ b/src/notify.py @@ -0,0 +1,40 @@ +from __future__ import annotations + + +class DesktopNotifier: + def __init__(self, app_name: str = "aman"): + self._app_name = app_name + self._notify = None + self._ready = False + self._attempted = False + + def send(self, title: str, body: str, *, error: bool = False) -> bool: + _ = error + notify = self._ensure_backend() + if notify is None: + return False + try: + notification = notify.Notification.new(title, body, None) + notification.show() + return True + except Exception: + return False + + def _ensure_backend(self): + if self._attempted: + return self._notify if self._ready else None + + self._attempted = True + try: + import gi + + gi.require_version("Notify", "0.7") + from gi.repository import Notify # type: ignore[import-not-found] + + if Notify.init(self._app_name): + self._notify = Notify + self._ready = True + except Exception: + self._notify = None + self._ready = False + return self._notify if self._ready else None diff --git a/src/pipelines_runtime.py b/src/pipelines_runtime.py new file mode 100644 index 0000000..8aa7985 --- /dev/null +++ b/src/pipelines_runtime.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import hashlib +import importlib.util +from pathlib import Path +from types import ModuleType +from typing import Any, Callable + +from engine import PipelineBinding, PipelineOptions + + +DEFAULT_PIPELINES_PATH = Path.home() / ".config" / "aman" / "pipelines.py" + + +def fingerprint(path: Path) 
-> str | None: + try: + data = path.read_bytes() + except FileNotFoundError: + return None + return hashlib.sha256(data).hexdigest() + + +def load_bindings( + *, + path: Path, + default_hotkey: str, + default_handler_factory: Callable[[], Callable[[Any, Any], str]], +) -> dict[str, PipelineBinding]: + if not path.exists(): + handler = default_handler_factory() + options = PipelineOptions(failure_policy="best_effort") + return { + default_hotkey: PipelineBinding( + hotkey=default_hotkey, + handler=handler, + options=options, + ) + } + + module = _load_module(path) + return _bindings_from_module(module) + + +def _load_module(path: Path) -> ModuleType: + module_name = f"aman_user_pipelines_{hash(path.resolve())}" + spec = importlib.util.spec_from_file_location(module_name, str(path)) + if spec is None or spec.loader is None: + raise ValueError(f"unable to load pipelines module: {path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _bindings_from_module(module: ModuleType) -> dict[str, PipelineBinding]: + raw_pipelines = getattr(module, "HOTKEY_PIPELINES", None) + if not isinstance(raw_pipelines, dict): + raise ValueError("HOTKEY_PIPELINES must be a dictionary") + if not raw_pipelines: + raise ValueError("HOTKEY_PIPELINES cannot be empty") + + raw_options = getattr(module, "PIPELINE_OPTIONS", {}) + if raw_options is None: + raw_options = {} + if not isinstance(raw_options, dict): + raise ValueError("PIPELINE_OPTIONS must be a dictionary when provided") + + bindings: dict[str, PipelineBinding] = {} + for key, handler in raw_pipelines.items(): + hotkey = _validate_hotkey(key) + if not callable(handler): + raise ValueError(f"HOTKEY_PIPELINES[{hotkey}] must be callable") + option_data = raw_options.get(hotkey, {}) + options = _parse_options(hotkey, option_data) + bindings[hotkey] = PipelineBinding( + hotkey=hotkey, + handler=handler, + options=options, + ) + return bindings + + +def _validate_hotkey(value: Any) 
-> str: + if not isinstance(value, str): + raise ValueError("pipeline hotkey keys must be strings") + hotkey = value.strip() + if not hotkey: + raise ValueError("pipeline hotkey keys cannot be empty") + return hotkey + + +def _parse_options(hotkey: str, value: Any) -> PipelineOptions: + if value is None: + return PipelineOptions() + if not isinstance(value, dict): + raise ValueError(f"PIPELINE_OPTIONS[{hotkey}] must be an object") + failure_policy = value.get("failure_policy", "best_effort") + return PipelineOptions(failure_policy=failure_policy) diff --git a/src/vocabulary.py b/src/vocabulary.py index 1a7ebb2..2629eb6 100644 --- a/src/vocabulary.py +++ b/src/vocabulary.py @@ -4,101 +4,7 @@ import re from dataclasses import dataclass from typing import Iterable -from config import DomainInferenceConfig, VocabularyConfig - - -DOMAIN_GENERAL = "general" -DOMAIN_PERSONAL_NAMES = "personal_names" -DOMAIN_SOFTWARE_DEV = "software_dev" -DOMAIN_OPS_INFRA = "ops_infra" -DOMAIN_BUSINESS = "business" -DOMAIN_MEDICAL_LEGAL = "medical_legal" - -DOMAIN_ORDER = ( - DOMAIN_PERSONAL_NAMES, - DOMAIN_SOFTWARE_DEV, - DOMAIN_OPS_INFRA, - DOMAIN_BUSINESS, - DOMAIN_MEDICAL_LEGAL, -) - -DOMAIN_KEYWORDS = { - DOMAIN_SOFTWARE_DEV: { - "api", - "bug", - "code", - "commit", - "docker", - "function", - "git", - "github", - "javascript", - "python", - "refactor", - "repository", - "typescript", - "unit", - "test", - }, - DOMAIN_OPS_INFRA: { - "cluster", - "container", - "deploy", - "deployment", - "incident", - "kubernetes", - "monitoring", - "nginx", - "pod", - "prod", - "service", - "systemd", - "terraform", - }, - DOMAIN_BUSINESS: { - "budget", - "client", - "deadline", - "finance", - "invoice", - "meeting", - "milestone", - "project", - "quarter", - "roadmap", - "sales", - "stakeholder", - }, - DOMAIN_MEDICAL_LEGAL: { - "agreement", - "case", - "claim", - "compliance", - "contract", - "diagnosis", - "liability", - "patient", - "prescription", - "regulation", - "symptom", - "treatment", - 
}, -} - -DOMAIN_PHRASES = { - DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"), - DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"), - DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"), - DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"), -} - -GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"} - - -@dataclass(frozen=True) -class DomainResult: - name: str - confidence: float +from config import VocabularyConfig @dataclass(frozen=True) @@ -108,10 +14,9 @@ class _ReplacementView: class VocabularyEngine: - def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig): + def __init__(self, vocab_cfg: VocabularyConfig): self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements] self._terms = list(vocab_cfg.terms) - self._domain_enabled = bool(domain_cfg.enabled) self._replacement_map = { _normalize_key(rule.source): rule.target for rule in self._replacements @@ -161,55 +66,6 @@ class VocabularyEngine: used += addition return "\n".join(out) - def infer_domain(self, text: str) -> DomainResult: - if not self._domain_enabled: - return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) - - normalized = text.casefold() - tokens = re.findall(r"[a-z0-9+#./_-]+", normalized) - if not tokens: - return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) - - scores = {domain: 0 for domain in DOMAIN_ORDER} - for token in tokens: - for domain, keywords in DOMAIN_KEYWORDS.items(): - if token in keywords: - scores[domain] += 2 - - for domain, phrases in DOMAIN_PHRASES.items(): - for phrase in phrases: - if phrase in normalized: - scores[domain] += 2 - - if any(token in GREETING_TOKENS for token in tokens): - scores[DOMAIN_PERSONAL_NAMES] += 1 - - # Boost domains from configured dictionary terms and replacement targets. 
- dictionary_tokens = self._dictionary_tokens() - for token in dictionary_tokens: - for domain, keywords in DOMAIN_KEYWORDS.items(): - if token in keywords and token in tokens: - scores[domain] += 1 - - top_domain = DOMAIN_GENERAL - top_score = 0 - total_score = 0 - for domain in DOMAIN_ORDER: - score = scores[domain] - total_score += score - if score > top_score: - top_score = score - top_domain = domain - - if top_score < 2 or total_score == 0: - return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) - - confidence = top_score / total_score - if confidence < 0.45: - return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) - - return DomainResult(name=top_domain, confidence=round(confidence, 2)) - def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str: items = _dedupe_preserve_order( [rule.target for rule in self._replacements] + self._terms @@ -236,19 +92,6 @@ class VocabularyEngine: return "" return prefix + hotwords - def _dictionary_tokens(self) -> set[str]: - values: list[str] = [] - for rule in self._replacements: - values.append(rule.source) - values.append(rule.target) - values.extend(self._terms) - - tokens: set[str] = set() - for value in values: - for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()): - tokens.add(token) - return tokens - def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None: unique_sources = _dedupe_preserve_order(list(sources)) diff --git a/tests/test_aman.py b/tests/test_aman.py index d9a9867..d0147ed 100644 --- a/tests/test_aman.py +++ b/tests/test_aman.py @@ -1,3 +1,4 @@ +import json import os import sys import tempfile @@ -11,14 +12,25 @@ if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) import aman -from config import Config, VocabularyReplacement +from config import Config, VocabularyReplacement, redacted_dict +from engine import PipelineBinding, PipelineOptions class FakeDesktop: def __init__(self): self.inject_calls = [] + self.hotkey_updates = [] + self.hotkeys = {} 
+ self.cancel_callback = None self.quit_calls = 0 + def set_hotkeys(self, bindings): + self.hotkeys = dict(bindings) + self.hotkey_updates.append(tuple(sorted(bindings.keys()))) + + def start_cancel_listener(self, callback): + self.cancel_callback = callback + def inject_text( self, text: str, @@ -76,12 +88,30 @@ class FakeAIProcessor: def process(self, text, lang="en", **_kwargs): return text + def chat(self, *, system_prompt, user_prompt, llm_opts=None): + _ = system_prompt + opts = llm_opts or {} + if "response_format" in opts: + payload = json.loads(user_prompt) + transcript = payload.get("transcript", "") + return json.dumps({"cleaned_text": transcript}) + return "general" + class FakeAudio: def __init__(self, size: int): self.size = size +class FakeNotifier: + def __init__(self): + self.events = [] + + def send(self, title, body, *, error=False): + self.events.append((title, body, error)) + return True + + class DaemonTests(unittest.TestCase): def _config(self) -> Config: cfg = Config() @@ -103,19 +133,23 @@ class DaemonTests(unittest.TestCase): @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) @patch("aman.start_audio_recording", return_value=(object(), object())) - def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock): + def test_hotkey_start_stop_injects_text(self, _start_mock, _stop_mock): desktop = FakeDesktop() daemon = self._build_daemon(desktop, FakeModel(), verbose=False) daemon._start_stop_worker = ( lambda stream, record, trigger, process_audio: daemon._stop_and_process( - stream, record, trigger, process_audio + stream, + record, + trigger, + process_audio, + daemon._pending_worker_hotkey, ) ) - daemon.toggle() + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) self.assertEqual(daemon.get_state(), aman.State.RECORDING) - daemon.toggle() + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) self.assertEqual(daemon.get_state(), aman.State.IDLE) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)]) @@ 
-131,7 +165,7 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.toggle() + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) self.assertEqual(daemon.get_state(), aman.State.RECORDING) self.assertTrue(daemon.shutdown(timeout=0.2)) @@ -153,8 +187,8 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.toggle() - daemon.toggle() + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)]) @@ -223,8 +257,8 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.toggle() - daemon.toggle() + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.handle_hotkey(daemon.cfg.daemon.hotkey) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)]) @@ -239,6 +273,161 @@ class DaemonTests(unittest.TestCase): any("DEBUG:root:state: idle -> recording" in line for line in logs.output) ) + def test_hotkey_dispatches_to_matching_pipeline(self): + desktop = FakeDesktop() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False) + daemon.set_pipeline_bindings( + { + "Super+m": PipelineBinding( + hotkey="Super+m", + handler=lambda _audio, _lib: "default", + options=PipelineOptions(failure_policy="best_effort"), + ), + "Super+Shift+m": PipelineBinding( + hotkey="Super+Shift+m", + handler=lambda _audio, _lib: "caps", + options=PipelineOptions(failure_policy="best_effort"), + ), + } + ) + + out = daemon._run_pipeline(object(), "Super+Shift+m") + self.assertEqual(out, "caps") + + def test_try_apply_config_applies_new_runtime_settings(self): + desktop = FakeDesktop() + cfg = self._config() + daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False) + daemon._stt_hint_kwargs_cache = {"hotwords": "old"} + + candidate = Config() + candidate.injection.backend = "injection" + candidate.vocabulary.replacements = [ + VocabularyReplacement(source="Martha", target="Marta"), + ] + + status, changed, error = 
daemon.try_apply_config(candidate) + + self.assertEqual(status, "applied") + self.assertEqual(error, "") + self.assertIn("injection.backend", changed) + self.assertIn("vocabulary.replacements", changed) + self.assertEqual(daemon.cfg.injection.backend, "injection") + self.assertEqual( + daemon.vocabulary.apply_deterministic_replacements("Martha"), + "Marta", + ) + self.assertIsNone(daemon._stt_hint_kwargs_cache) + + def test_try_apply_config_reloads_stt_model(self): + desktop = FakeDesktop() + cfg = self._config() + daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False) + + candidate = Config() + candidate.stt.model = "small" + next_model = FakeModel(text="from-new-model") + with patch("aman._build_whisper_model", return_value=next_model): + status, changed, error = daemon.try_apply_config(candidate) + + self.assertEqual(status, "applied") + self.assertEqual(error, "") + self.assertIn("stt.model", changed) + self.assertIs(daemon.model, next_model) + + def test_try_apply_config_is_deferred_while_busy(self): + desktop = FakeDesktop() + daemon = self._build_daemon(desktop, FakeModel(), verbose=False) + daemon.set_state(aman.State.RECORDING) + + status, changed, error = daemon.try_apply_config(Config()) + + self.assertEqual(status, "deferred") + self.assertEqual(changed, []) + self.assertIn("busy", error) + + +class ConfigReloaderTests(unittest.TestCase): + def _write_config(self, path: Path, cfg: Config): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(redacted_dict(cfg)), encoding="utf-8") + + def _daemon_for_path(self, path: Path) -> tuple[aman.Daemon, FakeDesktop]: + cfg = Config() + self._write_config(path, cfg) + desktop = FakeDesktop() + with patch("aman._build_whisper_model", return_value=FakeModel()), patch( + "aman.LlamaProcessor", return_value=FakeAIProcessor() + ): + daemon = aman.Daemon(cfg, desktop, verbose=False) + return daemon, desktop + + def test_reloader_applies_changed_config(self): + with 
tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + daemon, _desktop = self._daemon_for_path(path) + notifier = FakeNotifier() + reloader = aman.ConfigReloader( + daemon=daemon, + config_path=path, + notifier=notifier, + poll_interval_sec=0.01, + ) + + updated = Config() + updated.injection.backend = "injection" + self._write_config(path, updated) + reloader.tick() + + self.assertEqual(daemon.cfg.injection.backend, "injection") + self.assertTrue(any(evt[0] == "Config Reloaded" and not evt[2] for evt in notifier.events)) + + def test_reloader_keeps_last_good_config_when_invalid(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + daemon, _desktop = self._daemon_for_path(path) + notifier = FakeNotifier() + reloader = aman.ConfigReloader( + daemon=daemon, + config_path=path, + notifier=notifier, + poll_interval_sec=0.01, + ) + + path.write_text('{"injection":{"backend":"invalid"}}', encoding="utf-8") + reloader.tick() + self.assertEqual(daemon.cfg.injection.backend, "clipboard") + fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"] + self.assertEqual(len(fail_events), 1) + + reloader.tick() + fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"] + self.assertEqual(len(fail_events), 1) + + def test_reloader_defers_apply_until_idle(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + daemon, _desktop = self._daemon_for_path(path) + notifier = FakeNotifier() + reloader = aman.ConfigReloader( + daemon=daemon, + config_path=path, + notifier=notifier, + poll_interval_sec=0.01, + ) + + updated = Config() + updated.injection.backend = "injection" + self._write_config(path, updated) + + daemon.set_state(aman.State.RECORDING) + reloader.tick() + self.assertEqual(daemon.cfg.injection.backend, "clipboard") + + daemon.set_state(aman.State.IDLE) + reloader.tick() + self.assertEqual(daemon.cfg.injection.backend, "injection") + class 
LockTests(unittest.TestCase): def test_lock_rejects_second_instance(self): diff --git a/tests/test_config.py b/tests/test_config.py index f65a9f8..1e6dd13 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,7 +26,6 @@ class ConfigTests(unittest.TestCase): self.assertFalse(cfg.injection.remove_transcription_from_clipboard) self.assertEqual(cfg.vocabulary.replacements, []) self.assertEqual(cfg.vocabulary.terms, []) - self.assertTrue(cfg.domain_inference.enabled) self.assertTrue(missing.exists()) written = json.loads(missing.read_text(encoding="utf-8")) @@ -48,7 +47,6 @@ class ConfigTests(unittest.TestCase): ], "terms": ["Systemd", "Kubernetes"], }, - "domain_inference": {"enabled": True}, } with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" @@ -66,7 +64,6 @@ class ConfigTests(unittest.TestCase): self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha") self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta") self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"]) - self.assertTrue(cfg.domain_inference.enabled) def test_super_modifier_hotkey_is_valid(self): payload = {"daemon": {"hotkey": "Super+m"}} @@ -98,28 +95,6 @@ class ConfigTests(unittest.TestCase): ): load(str(path)) - def test_loads_legacy_keys(self): - payload = { - "hotkey": "Alt+m", - "input": "Mic", - "whisper_model": "tiny", - "whisper_device": "cpu", - "injection_backend": "clipboard", - } - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - cfg = load(str(path)) - - self.assertEqual(cfg.daemon.hotkey, "Alt+m") - self.assertEqual(cfg.recording.input, "Mic") - self.assertEqual(cfg.stt.model, "tiny") - self.assertEqual(cfg.stt.device, "cpu") - self.assertEqual(cfg.injection.backend, "clipboard") - self.assertFalse(cfg.injection.remove_transcription_from_clipboard) - self.assertEqual(cfg.vocabulary.replacements, []) - def 
test_invalid_injection_backend_raises(self): payload = {"injection": {"backend": "invalid"}} with tempfile.TemporaryDirectory() as td: @@ -138,41 +113,20 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"): load(str(path)) - def test_removed_ai_section_raises(self): - payload = {"ai": {"enabled": True}} + def test_unknown_top_level_fields_are_ignored(self): + payload = { + "custom_a": {"enabled": True}, + "custom_b": {"nested": "value"}, + "custom_c": 123, + } with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" path.write_text(json.dumps(payload), encoding="utf-8") - with self.assertRaisesRegex(ValueError, "ai section is no longer supported"): - load(str(path)) + cfg = load(str(path)) - def test_removed_legacy_ai_enabled_raises(self): - payload = {"ai_enabled": True} - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - with self.assertRaisesRegex(ValueError, "ai_enabled is no longer supported"): - load(str(path)) - - def test_removed_logging_section_raises(self): - payload = {"logging": {"log_transcript": True}} - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - with self.assertRaisesRegex(ValueError, "no longer supported"): - load(str(path)) - - def test_removed_legacy_log_transcript_raises(self): - payload = {"log_transcript": True} - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - with self.assertRaisesRegex(ValueError, "no longer supported"): - load(str(path)) + self.assertEqual(cfg.daemon.hotkey, "Cmd+m") + self.assertEqual(cfg.injection.backend, "clipboard") def test_conflicting_replacements_raise(self): payload = { @@ -224,32 +178,15 @@ class ConfigTests(unittest.TestCase): with 
self.assertRaisesRegex(ValueError, "wildcard"): load(str(path)) - def test_removed_domain_mode_raises(self): - payload = {"domain_inference": {"mode": "heuristic"}} + def test_unknown_vocabulary_fields_are_ignored(self): + payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}} with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" path.write_text(json.dumps(payload), encoding="utf-8") - with self.assertRaisesRegex(ValueError, "domain_inference.mode is no longer supported"): - load(str(path)) + cfg = load(str(path)) - def test_removed_vocabulary_max_rules_raises(self): - payload = {"vocabulary": {"max_rules": 100}} - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - with self.assertRaisesRegex(ValueError, "vocabulary.max_rules is no longer supported"): - load(str(path)) - - def test_removed_vocabulary_max_terms_raises(self): - payload = {"vocabulary": {"max_terms": 100}} - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - with self.assertRaisesRegex(ValueError, "vocabulary.max_terms is no longer supported"): - load(str(path)) + self.assertEqual(cfg.vocabulary.terms, ["Docker"]) if __name__ == "__main__": diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..4253479 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,92 @@ +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions + + +class EngineTests(unittest.TestCase): + def test_best_effort_pipeline_failure_returns_empty_string(self): + lib = PipelineLib( + transcribe_fn=lambda *_args, **_kwargs: "text", + llm_fn=lambda **_kwargs: "result", + ) + engine = 
Engine(lib) + + def failing_pipeline(_audio, _lib): + raise RuntimeError("boom") + + binding = PipelineBinding( + hotkey="Cmd+m", + handler=failing_pipeline, + options=PipelineOptions(failure_policy="best_effort"), + ) + out = engine.run(binding, object()) + self.assertEqual(out, "") + + def test_strict_pipeline_failure_raises(self): + lib = PipelineLib( + transcribe_fn=lambda *_args, **_kwargs: "text", + llm_fn=lambda **_kwargs: "result", + ) + engine = Engine(lib) + + def failing_pipeline(_audio, _lib): + raise RuntimeError("boom") + + binding = PipelineBinding( + hotkey="Cmd+m", + handler=failing_pipeline, + options=PipelineOptions(failure_policy="strict"), + ) + with self.assertRaisesRegex(RuntimeError, "boom"): + engine.run(binding, object()) + + def test_pipeline_lib_forwards_arguments(self): + seen = {} + + def transcribe_fn(audio, *, hints=None, whisper_opts=None): + seen["audio"] = audio + seen["hints"] = hints + seen["whisper_opts"] = whisper_opts + return "hello" + + def llm_fn(*, system_prompt, user_prompt, llm_opts=None): + seen["system_prompt"] = system_prompt + seen["user_prompt"] = user_prompt + seen["llm_opts"] = llm_opts + return "world" + + lib = PipelineLib( + transcribe_fn=transcribe_fn, + llm_fn=llm_fn, + ) + + audio = object() + self.assertEqual( + lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}), + "hello", + ) + self.assertEqual( + lib.llm( + system_prompt="sys", + user_prompt="user", + llm_opts={"temperature": 0.2}, + ), + "world", + ) + self.assertIs(seen["audio"], audio) + self.assertEqual(seen["hints"], ["Docker"]) + self.assertEqual(seen["whisper_opts"], {"vad_filter": True}) + self.assertEqual(seen["system_prompt"], "sys") + self.assertEqual(seen["user_prompt"], "user") + self.assertEqual(seen["llm_opts"], {"temperature": 0.2}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines_runtime.py b/tests/test_pipelines_runtime.py new file mode 100644 index 0000000..4bc644b --- 
/dev/null +++ b/tests/test_pipelines_runtime.py @@ -0,0 +1,109 @@ +import tempfile +import textwrap +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from pipelines_runtime import fingerprint, load_bindings + + +class PipelinesRuntimeTests(unittest.TestCase): + def test_missing_file_uses_default_binding(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "pipelines.py" + called = {"count": 0} + + def factory(): + called["count"] += 1 + + def handler(_audio, _lib): + return "ok" + + return handler + + bindings = load_bindings( + path=path, + default_hotkey="Cmd+m", + default_handler_factory=factory, + ) + + self.assertEqual(list(bindings.keys()), ["Cmd+m"]) + self.assertEqual(called["count"], 1) + + def test_loads_module_bindings_and_options(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "pipelines.py" + path.write_text( + textwrap.dedent( + """ + def p1(audio, lib): + return "one" + + def p2(audio, lib): + return "two" + + HOTKEY_PIPELINES = { + "Cmd+m": p1, + "Ctrl+space": p2, + } + + PIPELINE_OPTIONS = { + "Ctrl+space": {"failure_policy": "strict"}, + } + """ + ), + encoding="utf-8", + ) + + bindings = load_bindings( + path=path, + default_hotkey="Cmd+m", + default_handler_factory=lambda: (lambda _audio, _lib: "default"), + ) + + self.assertEqual(sorted(bindings.keys()), ["Cmd+m", "Ctrl+space"]) + self.assertEqual(bindings["Cmd+m"].options.failure_policy, "best_effort") + self.assertEqual(bindings["Ctrl+space"].options.failure_policy, "strict") + + def test_invalid_options_fail(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "pipelines.py" + path.write_text( + textwrap.dedent( + """ + def p(audio, lib): + return "ok" + HOTKEY_PIPELINES = {"Cmd+m": p} + PIPELINE_OPTIONS = {"Cmd+m": {"failure_policy": "invalid"}} + """ + ), + encoding="utf-8", + ) + + with 
self.assertRaisesRegex(ValueError, "failure_policy"): + load_bindings( + path=path, + default_hotkey="Cmd+m", + default_handler_factory=lambda: (lambda _audio, _lib: "default"), + ) + + def test_fingerprint_changes_with_content(self): + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "pipelines.py" + self.assertIsNone(fingerprint(path)) + path.write_text("HOTKEY_PIPELINES = {}", encoding="utf-8") + first = fingerprint(path) + path.write_text("HOTKEY_PIPELINES = {'Cmd+m': lambda a, l: ''}", encoding="utf-8") + second = fingerprint(path) + self.assertIsNotNone(first) + self.assertIsNotNone(second) + self.assertNotEqual(first, second) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vocabulary.py b/tests/test_vocabulary.py index bc33ea6..fd66a17 100644 --- a/tests/test_vocabulary.py +++ b/tests/test_vocabulary.py @@ -7,18 +7,17 @@ SRC = ROOT / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) -from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement -from vocabulary import DOMAIN_GENERAL, VocabularyEngine +from config import VocabularyConfig, VocabularyReplacement +from vocabulary import VocabularyEngine class VocabularyEngineTests(unittest.TestCase): - def _engine(self, replacements=None, terms=None, domain_enabled=True): + def _engine(self, replacements=None, terms=None): vocab = VocabularyConfig( replacements=replacements or [], terms=terms or [], ) - domain = DomainInferenceConfig(enabled=domain_enabled) - return VocabularyEngine(vocab, domain) + return VocabularyEngine(vocab) def test_boundary_aware_replacement(self): engine = self._engine( @@ -50,27 +49,5 @@ class VocabularyEngineTests(unittest.TestCase): self.assertLessEqual(len(hotwords), 1024) self.assertLessEqual(len(prompt), 600) - def test_domain_inference_general_fallback(self): - engine = self._engine() - result = engine.infer_domain("please call me later") - - self.assertEqual(result.name, DOMAIN_GENERAL) - 
self.assertEqual(result.confidence, 0.0) - - def test_domain_inference_for_technical_text(self): - engine = self._engine(terms=["Docker", "Systemd"]) - result = engine.infer_domain("restart Docker and systemd service on prod") - - self.assertNotEqual(result.name, DOMAIN_GENERAL) - self.assertGreater(result.confidence, 0.0) - - def test_domain_inference_can_be_disabled(self): - engine = self._engine(domain_enabled=False) - result = engine.infer_domain("please restart docker") - - self.assertEqual(result.name, DOMAIN_GENERAL) - self.assertEqual(result.confidence, 0.0) - - if __name__ == "__main__": unittest.main()