From 5b38cc7dcdec13a9118f1bf0ca584df124d5e889 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 26 Feb 2026 12:54:47 -0300 Subject: [PATCH] Revert "Add pipeline engine and remove legacy compatibility paths" --- README.md | 88 +---- pipelines.example.py | 50 --- src/aiprocess.py | 39 +- src/aman.py | 678 +++----------------------------- src/config.py | 41 ++ src/constants.py | 1 - src/desktop.py | 2 +- src/desktop_wayland.py | 2 +- src/desktop_x11.py | 107 +---- src/engine.py | 86 ---- src/notify.py | 40 -- src/pipelines_runtime.py | 97 ----- src/vocabulary.py | 161 +++++++- tests/test_aman.py | 209 +--------- tests/test_config.py | 89 ++++- tests/test_engine.py | 92 ----- tests/test_pipelines_runtime.py | 109 ----- tests/test_vocabulary.py | 31 +- 18 files changed, 399 insertions(+), 1523 deletions(-) delete mode 100644 pipelines.example.py delete mode 100644 src/engine.py delete mode 100644 src/notify.py delete mode 100644 src/pipelines_runtime.py delete mode 100644 tests/test_engine.py delete mode 100644 tests/test_pipelines_runtime.py diff --git a/README.md b/README.md index 55cca01..c7510b1 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,8 @@ Create `~/.config/aman/config.json` (or let `aman` create it automatically on fi { "from": "docker", "to": "Docker" } ], "terms": ["Systemd", "Kubernetes"] - } + }, + "domain_inference": { "enabled": true } } ``` @@ -91,9 +92,6 @@ Hotkey notes: - Use one key plus optional modifiers (for example `Cmd+m`, `Super+m`, `Ctrl+space`). - `Super` and `Cmd` are equivalent aliases for the same modifier. -- Invalid hotkey syntax in config prevents startup/reload. -- When `~/.config/aman/pipelines.py` exists, hotkeys come from `HOTKEY_PIPELINES`. -- `daemon.hotkey` is used as the fallback/default hotkey only when no pipelines file is present. AI cleanup is always enabled and uses the locked local Llama-3.2-3B GGUF model downloaded to `~/.cache/aman/models/` during daemon initialization. @@ -109,6 +107,11 @@ Vocabulary correction: - Wildcards are intentionally rejected (`*`, `?`, `[`, `]`, `{`, `}`) to avoid ambiguous rules. - Rules are deduplicated case-insensitively; conflicting replacements are rejected. +Domain inference: + +- Domain context is advisory only and is used to improve cleanup prompts. +- When confidence is low, it falls back to `general` context. + STT hinting: - Vocabulary is passed to Whisper as `hotwords`/`initial_prompt` only when those @@ -131,12 +134,6 @@ systemctl --user enable --now aman - Press it again to stop and run STT. - Press `Esc` while recording to cancel without processing. - Transcript contents are logged only when `-v/--verbose` is used. -- Config changes are hot-reloaded automatically (polled every 1 second). -- `~/.config/aman/pipelines.py` changes are hot-reloaded automatically (polled every 1 second). -- Send `SIGHUP` to force an immediate reload of config and pipelines: - `systemctl --user kill -s HUP aman` (or send `HUP` to the process directly). -- Reloads are applied when the daemon is idle; invalid updates are rejected and the last valid config stays active. -- Reload success/failure is logged, and desktop notifications are shown when available. Wayland note: @@ -152,77 +149,6 @@ AI processing: - Local llama.cpp model only (no remote provider configuration). -## Pipelines API - -`aman` is split into: - -- shell daemon: hotkeys, recording/cancel, and desktop injection -- pipeline engine: `lib.transcribe(...)` and `lib.llm(...)` -- pipeline implementation: Python callables mapped per hotkey - -Pipeline file path: - -- `~/.config/aman/pipelines.py` -- You can start from [`pipelines.example.py`](./pipelines.example.py). -- If `pipelines.py` is missing, `aman` uses a built-in reference pipeline bound to `daemon.hotkey`. -- If `pipelines.py` exists but is invalid, startup fails fast. -- Pipelines are hot-reloaded automatically when the module file changes. -- Send `SIGHUP` to force an immediate reload of both config and pipelines. - -Expected module exports: - -```python -HOTKEY_PIPELINES = { - "Super+m": my_pipeline, - "Super+Shift+m": caps_pipeline, -} - -PIPELINE_OPTIONS = { - "Super+Shift+m": {"failure_policy": "strict"}, # optional -} -``` - -Pipeline callable signature: - -```python -def my_pipeline(audio, lib) -> str: - text = lib.transcribe(audio) - context = lib.llm( - system_prompt="context system prompt", - user_prompt=f"Transcript: {text}", - ) - out = lib.llm( - system_prompt="amanuensis prompt", - user_prompt=f"context={context}\ntext={text}", - ) - return out -``` - -`lib` API: - -- `lib.transcribe(audio, hints=None, whisper_opts=None) -> str` -- `lib.llm(system_prompt=..., user_prompt=..., llm_opts=None) -> str` - -Failure policy options: - -- `best_effort` (default): pipeline errors return empty output -- `strict`: pipeline errors abort the current run - -Validation: - -- `HOTKEY_PIPELINES` must be a non-empty dictionary. -- Every hotkey key must be a non-empty string. -- Every pipeline value must be callable. -- `PIPELINE_OPTIONS` must be a dictionary when provided. - -Reference behavior: - -- The built-in fallback pipeline (used when `pipelines.py` is missing) uses `lib.llm(...)` twice: - - first to infer context - - second to run the amanuensis rewrite -- The second pass requests JSON output and expects `{"cleaned_text": "..."}`. -- Deterministic dictionary replacements are then applied as part of that reference implementation. - Control: ```bash diff --git a/pipelines.example.py b/pipelines.example.py deleted file mode 100644 index d5d92f4..0000000 --- a/pipelines.example.py +++ /dev/null @@ -1,50 +0,0 @@ -import json - - -CONTEXT_PROMPT = ( - "Return a concise plain-text context hint (max 12 words) " - "for the transcript domain and style." -) - -AMANUENSIS_PROMPT = ( - "Rewrite the transcript into clean prose without changing intent. " - "Return only the final text." -) - - -def default_pipeline(audio, lib) -> str: - text = lib.transcribe(audio) - if not text: - return "" - - context = lib.llm( - system_prompt=CONTEXT_PROMPT, - user_prompt=json.dumps({"transcript": text}, ensure_ascii=False), - llm_opts={"temperature": 0.0}, - ).strip() - - output = lib.llm( - system_prompt=AMANUENSIS_PROMPT, - user_prompt=json.dumps( - {"transcript": text, "context": context}, - ensure_ascii=False, - ), - llm_opts={"temperature": 0.0}, - ) - return output.strip() - - -def caps_pipeline(audio, lib) -> str: - return lib.transcribe(audio).upper() - - -HOTKEY_PIPELINES = { - "Super+m": default_pipeline, - "Super+Shift+m": caps_pipeline, -} - - -PIPELINE_OPTIONS = { - "Super+m": {"failure_policy": "best_effort"}, - "Super+Shift+m": {"failure_policy": "strict"}, -} diff --git a/src/aiprocess.py b/src/aiprocess.py index d423271..dec2e4b 100644 --- a/src/aiprocess.py +++ b/src/aiprocess.py @@ -76,41 +76,18 @@ class LlamaProcessor: if cleaned_dictionary: request_payload["dictionary"] = cleaned_dictionary - content = self.chat( - system_prompt=SYSTEM_PROMPT, - user_prompt=json.dumps(request_payload, ensure_ascii=False), - llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}}, - ) - return _extract_cleaned_text_from_raw(content) - - def chat( - self, - *, - system_prompt: str, - user_prompt: str, - llm_opts: dict[str, Any] | None = None, - ) -> str: - opts = dict(llm_opts or {}) - temperature = float(opts.pop("temperature", 0.0)) - response_format = opts.pop("response_format", None) - if opts: - unknown = ", ".join(sorted(opts.keys())) - raise ValueError(f"unsupported llm options: {unknown}") - kwargs: dict[str, Any] = { "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": json.dumps(request_payload, ensure_ascii=False)}, ], - "temperature": temperature, + "temperature": 0.0, } - if response_format is not None and _supports_response_format( - self.client.create_chat_completion - ): - kwargs["response_format"] = response_format + if _supports_response_format(self.client.create_chat_completion): + kwargs["response_format"] = {"type": "json_object"} response = self.client.create_chat_completion(**kwargs) - return _extract_chat_text(response) + return _extract_cleaned_text(response) def ensure_model(): @@ -171,10 +148,6 @@ def _extract_chat_text(payload: Any) -> str: def _extract_cleaned_text(payload: Any) -> str: raw = _extract_chat_text(payload) - return _extract_cleaned_text_from_raw(raw) - - -def _extract_cleaned_text_from_raw(raw: str) -> str: try: parsed = json.loads(raw) except json.JSONDecodeError as exc: diff --git a/src/aman.py b/src/aman.py index cb79903..0ffbc20 100755 --- a/src/aman.py +++ b/src/aman.py @@ -3,7 +3,6 @@ from __future__ import annotations import argparse import errno -import hashlib import inspect import json import logging @@ -13,21 +12,12 @@ import sys import threading import time from pathlib import Path -from typing import Any, Callable +from typing import Any -from aiprocess import LlamaProcessor, SYSTEM_PROMPT +from aiprocess import LlamaProcessor from config import Config, load, redacted_dict -from constants import ( - CONFIG_RELOAD_POLL_INTERVAL_SEC, - DEFAULT_CONFIG_PATH, - RECORD_TIMEOUT_SEC, - STT_LANGUAGE, -) +from constants import RECORD_TIMEOUT_SEC, STT_LANGUAGE from desktop import get_desktop_adapter -from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions -from notify import DesktopNotifier -from pipelines_runtime import DEFAULT_PIPELINES_PATH, fingerprint as pipelines_fingerprint -from pipelines_runtime import load_bindings from recorder import start_recording as start_audio_recording from recorder import stop_recording as stop_audio_recording from vocabulary import VocabularyEngine @@ -65,52 +55,6 @@ def _compute_type(device: str) -> str: return "int8" -def _replacement_view(cfg: Config) -> list[tuple[str, str]]: - return [(item.source, item.target) for item in cfg.vocabulary.replacements] - - -def _config_fields_changed(current: Config, candidate: Config) -> list[str]: - changed: list[str] = [] - if current.daemon.hotkey != candidate.daemon.hotkey: - changed.append("daemon.hotkey") - if current.recording.input != candidate.recording.input: - changed.append("recording.input") - if current.stt.model != candidate.stt.model: - changed.append("stt.model") - if current.stt.device != candidate.stt.device: - changed.append("stt.device") - if current.injection.backend != candidate.injection.backend: - changed.append("injection.backend") - if ( - current.injection.remove_transcription_from_clipboard - != candidate.injection.remove_transcription_from_clipboard - ): - changed.append("injection.remove_transcription_from_clipboard") - if _replacement_view(current) != _replacement_view(candidate): - changed.append("vocabulary.replacements") - if current.vocabulary.terms != candidate.vocabulary.terms: - changed.append("vocabulary.terms") - return changed - - -CONTEXT_SYSTEM_PROMPT = ( - "You infer concise writing context from transcript text.\n" - "Return a short plain-text hint (max 12 words) that describes the likely domain and style.\n" - "Do not add punctuation, quotes, XML tags, or markdown." -) - - -def _extract_cleaned_text_from_llm_raw(raw: str) -> str: - parsed = json.loads(raw) - if isinstance(parsed, str): - return parsed - if isinstance(parsed, dict): - cleaned_text = parsed.get("cleaned_text") - if isinstance(cleaned_text, str): - return cleaned_text - raise RuntimeError("unexpected ai output format: missing cleaned_text") - - class Daemon: def __init__(self, cfg: Config, desktop, *, verbose: bool = False): self.cfg = cfg @@ -122,34 +66,17 @@ class Daemon: self.stream = None self.record = None self.timer: threading.Timer | None = None - self._config_apply_in_progress = False - self._active_hotkey: str | None = None - self._pending_worker_hotkey: str | None = None - self._pipeline_bindings: dict[str, PipelineBinding] = {} - - self.model = _build_whisper_model(cfg.stt.model, cfg.stt.device) + self.model = _build_whisper_model( + cfg.stt.model, + cfg.stt.device, + ) logging.info("initializing ai processor") self.ai_processor = LlamaProcessor(verbose=self.verbose) logging.info("ai processor ready") self.log_transcript = verbose - self.vocabulary = VocabularyEngine(cfg.vocabulary) + self.vocabulary = VocabularyEngine(cfg.vocabulary, cfg.domain_inference) self._stt_hint_kwargs_cache: dict[str, Any] | None = None - self.lib = PipelineLib( - transcribe_fn=self._transcribe_with_options, - llm_fn=self._llm_with_options, - ) - self.engine = Engine(self.lib) - self.set_pipeline_bindings( - { - cfg.daemon.hotkey: PipelineBinding( - hotkey=cfg.daemon.hotkey, - handler=self.build_reference_pipeline(), - options=PipelineOptions(failure_policy="best_effort"), - ) - } - ) - def set_state(self, state: str): with self.lock: prev = self.state @@ -166,74 +93,16 @@ class Daemon: def request_shutdown(self): self._shutdown_requested.set() - def current_hotkeys(self) -> list[str]: - with self.lock: - return list(self._pipeline_bindings.keys()) - - def set_pipeline_bindings(self, bindings: dict[str, PipelineBinding]) -> None: - if not bindings: - raise ValueError("at least one pipeline binding is required") - with self.lock: - self._pipeline_bindings = dict(bindings) - if self._active_hotkey and self._active_hotkey not in self._pipeline_bindings: - self._active_hotkey = None - - def build_hotkey_callbacks( - self, - hotkeys: list[str], - *, - dry_run: bool, - ) -> dict[str, Callable[[], None]]: - callbacks: dict[str, Callable[[], None]] = {} - for hotkey in hotkeys: - if dry_run: - callbacks[hotkey] = (lambda hk=hotkey: logging.info("hotkey pressed: %s (dry-run)", hk)) - else: - callbacks[hotkey] = (lambda hk=hotkey: self.handle_hotkey(hk)) - return callbacks - - def apply_pipeline_bindings( - self, - bindings: dict[str, PipelineBinding], - callbacks: dict[str, Callable[[], None]], - ) -> tuple[str, str]: - with self.lock: - if self._shutdown_requested.is_set(): - return "deferred", "shutdown in progress" - if self._config_apply_in_progress: - return "deferred", "reload in progress" - if self.state != State.IDLE: - return "deferred", f"daemon is busy ({self.state})" - self._config_apply_in_progress = True - try: - self.desktop.set_hotkeys(callbacks) - with self.lock: - if self._shutdown_requested.is_set(): - return "deferred", "shutdown in progress" - if self.state != State.IDLE: - return "deferred", f"daemon is busy ({self.state})" - self._pipeline_bindings = dict(bindings) - return "applied", "" - except Exception as exc: - return "error", str(exc) - finally: - with self.lock: - self._config_apply_in_progress = False - - def handle_hotkey(self, hotkey: str): + def toggle(self): should_stop = False with self.lock: if self._shutdown_requested.is_set(): logging.info("shutdown in progress, trigger ignored") return - if self._config_apply_in_progress: - logging.info("reload in progress, trigger ignored") - return if self.state == State.IDLE: - self._active_hotkey = hotkey self._start_recording_locked() return - if self.state == State.RECORDING and self._active_hotkey == hotkey: + if self.state == State.RECORDING: should_stop = True else: logging.info("busy (%s), trigger ignored", self.state) @@ -244,9 +113,6 @@ class Daemon: if self.state != State.IDLE: logging.info("busy (%s), trigger ignored", self.state) return - if self._config_apply_in_progress: - logging.info("reload in progress, trigger ignored") - return try: stream, record = start_audio_recording(self.cfg.recording.input) except Exception as exc: @@ -267,17 +133,10 @@ class Daemon: def _timeout_stop(self): self.stop_recording(trigger="timeout") - def _start_stop_worker( - self, - stream: Any, - record: Any, - trigger: str, - process_audio: bool, - ): - active_hotkey = self._pending_worker_hotkey or self.cfg.daemon.hotkey + def _start_stop_worker(self, stream: Any, record: Any, trigger: str, process_audio: bool): threading.Thread( target=self._stop_and_process, - args=(stream, record, trigger, process_audio, active_hotkey), + args=(stream, record, trigger, process_audio), daemon=True, ).start() @@ -286,8 +145,6 @@ class Daemon: return None stream = self.stream record = self.record - active_hotkey = self._active_hotkey or self.cfg.daemon.hotkey - self._active_hotkey = None self.stream = None self.record = None if self.timer: @@ -301,17 +158,9 @@ class Daemon: logging.warning("recording resources are unavailable during stop") self.state = State.IDLE return None - return stream, record, active_hotkey + return stream, record - def _stop_and_process( - self, - stream: Any, - record: Any, - trigger: str, - process_audio: bool, - active_hotkey: str | None = None, - ): - hotkey = active_hotkey or self.cfg.daemon.hotkey + def _stop_and_process(self, stream: Any, record: Any, trigger: str, process_audio: bool): logging.info("stopping recording (%s)", trigger) try: audio = stop_audio_recording(stream, record) @@ -330,17 +179,44 @@ class Daemon: return try: - logging.info("pipeline started (%s)", hotkey) - text = self._run_pipeline(audio, hotkey).strip() + self.set_state(State.STT) + logging.info("stt started") + text = self._transcribe(audio) except Exception as exc: - logging.error("pipeline failed: %s", exc) + logging.error("stt failed: %s", exc) self.set_state(State.IDLE) return + text = (text or "").strip() if not text: self.set_state(State.IDLE) return + if self.log_transcript: + logging.debug("stt: %s", text) + else: + logging.info("stt produced %d chars", len(text)) + + domain = self.vocabulary.infer_domain(text) + if not self._shutdown_requested.is_set(): + self.set_state(State.PROCESSING) + logging.info("ai processing started") + try: + processor = self._get_ai_processor() + ai_text = processor.process( + text, + lang=STT_LANGUAGE, + dictionary_context=self.vocabulary.build_ai_dictionary_context(), + domain_name=domain.name, + domain_confidence=domain.confidence, + ) + if ai_text and ai_text.strip(): + text = ai_text.strip() + except Exception as exc: + logging.error("ai process failed: %s", exc) + + text = self.vocabulary.apply_deterministic_replacements(text).strip() + if self.log_transcript: logging.debug("processed: %s", text) else: @@ -366,23 +242,13 @@ class Daemon: finally: self.set_state(State.IDLE) - def _run_pipeline(self, audio: Any, hotkey: str) -> str: - with self.lock: - binding = self._pipeline_bindings.get(hotkey) - if binding is None and self._pipeline_bindings: - binding = next(iter(self._pipeline_bindings.values())) - if binding is None: - return "" - return self.engine.run(binding, audio) - def stop_recording(self, *, trigger: str = "user", process_audio: bool = True): payload = None with self.lock: payload = self._begin_stop_locked() if payload is None: return - stream, record, active_hotkey = payload - self._pending_worker_hotkey = active_hotkey + stream, record = payload self._start_stop_worker(stream, record, trigger, process_audio) def cancel_recording(self): @@ -404,129 +270,40 @@ class Daemon: time.sleep(0.05) return self.get_state() == State.IDLE - def build_reference_pipeline(self) -> Callable[[Any, PipelineLib], str]: - def _reference_pipeline(audio: Any, lib: PipelineLib) -> str: - text = (lib.transcribe(audio) or "").strip() - if not text: - return "" - dictionary_context = self.vocabulary.build_ai_dictionary_context() - context_prompt = { - "transcript": text, - "dictionary": dictionary_context, - } - context = ( - lib.llm( - system_prompt=CONTEXT_SYSTEM_PROMPT, - user_prompt=json.dumps(context_prompt, ensure_ascii=False), - llm_opts={"temperature": 0.0}, - ) - .strip() - ) - payload: dict[str, Any] = { - "language": STT_LANGUAGE, - "transcript": text, - "context": context, - } - if dictionary_context: - payload["dictionary"] = dictionary_context - ai_raw = lib.llm( - system_prompt=SYSTEM_PROMPT, - user_prompt=json.dumps(payload, ensure_ascii=False), - llm_opts={"temperature": 0.0, "response_format": {"type": "json_object"}}, - ) - cleaned_text = _extract_cleaned_text_from_llm_raw(ai_raw) - return self.vocabulary.apply_deterministic_replacements(cleaned_text).strip() - - return _reference_pipeline - def _transcribe(self, audio) -> str: - return self._transcribe_with_options(audio) - - def _transcribe_with_options( - self, - audio: Any, - *, - hints: list[str] | None = None, - whisper_opts: dict[str, Any] | None = None, - ) -> str: kwargs: dict[str, Any] = { "language": STT_LANGUAGE, "vad_filter": True, } - kwargs.update(self._stt_hint_kwargs(hints=hints)) - if whisper_opts: - kwargs.update(whisper_opts) + kwargs.update(self._stt_hint_kwargs()) segments, _info = self.model.transcribe(audio, **kwargs) parts = [] for seg in segments: text = (seg.text or "").strip() if text: parts.append(text) - out = " ".join(parts).strip() - if self.log_transcript: - logging.debug("stt: %s", out) - else: - logging.info("stt produced %d chars", len(out)) - return out - - def _llm_with_options( - self, - *, - system_prompt: str, - user_prompt: str, - llm_opts: dict[str, Any] | None = None, - ) -> str: - if self.get_state() != State.PROCESSING: - self.set_state(State.PROCESSING) - logging.info("llm processing started") - processor = self._get_ai_processor() - return processor.chat( - system_prompt=system_prompt, - user_prompt=user_prompt, - llm_opts=llm_opts, - ) + return " ".join(parts).strip() def _get_ai_processor(self) -> LlamaProcessor: if self.ai_processor is None: raise RuntimeError("ai processor is not initialized") return self.ai_processor - def _stt_hint_kwargs(self, *, hints: list[str] | None = None) -> dict[str, Any]: - if not hints and self._stt_hint_kwargs_cache is not None: + def _stt_hint_kwargs(self) -> dict[str, Any]: + if self._stt_hint_kwargs_cache is not None: return self._stt_hint_kwargs_cache hotwords, initial_prompt = self.vocabulary.build_stt_hints() - extra_hints = [item.strip() for item in (hints or []) if isinstance(item, str) and item.strip()] - if extra_hints: - words = [item.strip() for item in hotwords.split(",") if item.strip()] - words.extend(extra_hints) - deduped: list[str] = [] - seen: set[str] = set() - for item in words: - key = item.casefold() - if key in seen: - continue - seen.add(key) - deduped.append(item) - merged = ", ".join(deduped) - hotwords = merged[:1024] - if hotwords: - initial_prompt = f"Preferred vocabulary: {hotwords}"[:600] - if not hotwords and not initial_prompt: - if not hints: - self._stt_hint_kwargs_cache = {} - return self._stt_hint_kwargs_cache - return {} + self._stt_hint_kwargs_cache = {} + return self._stt_hint_kwargs_cache try: signature = inspect.signature(self.model.transcribe) except (TypeError, ValueError): logging.debug("stt signature inspection failed; skipping hints") - if not hints: - self._stt_hint_kwargs_cache = {} - return self._stt_hint_kwargs_cache - return {} + self._stt_hint_kwargs_cache = {} + return self._stt_hint_kwargs_cache params = signature.parameters kwargs: dict[str, Any] = {} @@ -536,288 +313,8 @@ class Daemon: kwargs["initial_prompt"] = initial_prompt if not kwargs: logging.debug("stt hint arguments are not supported by this whisper runtime") - if not hints: - self._stt_hint_kwargs_cache = kwargs - return self._stt_hint_kwargs_cache - return kwargs - - def try_apply_config(self, candidate: Config) -> tuple[str, list[str], str]: - with self.lock: - if self._shutdown_requested.is_set(): - return "deferred", [], "shutdown in progress" - if self._config_apply_in_progress: - return "deferred", [], "config reload already in progress" - if self.state != State.IDLE: - return "deferred", [], f"daemon is busy ({self.state})" - self._config_apply_in_progress = True - current_cfg = self.cfg - - try: - changed = _config_fields_changed(current_cfg, candidate) - if not changed: - return "applied", [], "" - - next_model = None - if ( - current_cfg.stt.model != candidate.stt.model - or current_cfg.stt.device != candidate.stt.device - ): - try: - next_model = _build_whisper_model(candidate.stt.model, candidate.stt.device) - except Exception as exc: - return "error", [], f"stt model reload failed: {exc}" - - try: - next_vocab = VocabularyEngine(candidate.vocabulary) - except Exception as exc: - return "error", [], f"vocabulary reload failed: {exc}" - - with self.lock: - if self._shutdown_requested.is_set(): - return "deferred", [], "shutdown in progress" - if self.state != State.IDLE: - return "deferred", [], f"daemon is busy ({self.state})" - self.cfg = candidate - if next_model is not None: - self.model = next_model - self.vocabulary = next_vocab - self._stt_hint_kwargs_cache = None - return "applied", changed, "" - finally: - with self.lock: - self._config_apply_in_progress = False - - -class ConfigReloader: - def __init__( - self, - *, - daemon: Daemon, - config_path: Path, - notifier: DesktopNotifier | None = None, - poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC, - ): - self.daemon = daemon - self.config_path = config_path - self.notifier = notifier - self.poll_interval_sec = poll_interval_sec - self._stop_event = threading.Event() - self._request_lock = threading.Lock() - self._pending_reload_reason = "" - self._pending_force = False - self._last_seen_fingerprint = self._fingerprint() - self._thread: threading.Thread | None = None - - def start(self) -> None: - if self._thread is not None: - return - self._thread = threading.Thread(target=self._run, daemon=True, name="config-reloader") - self._thread.start() - - def stop(self, timeout: float = 2.0) -> None: - self._stop_event.set() - if self._thread is not None: - self._thread.join(timeout=timeout) - self._thread = None - - def request_reload(self, reason: str, *, force: bool = False) -> None: - with self._request_lock: - if self._pending_reload_reason: - self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}" - else: - self._pending_reload_reason = reason - self._pending_force = self._pending_force or force - logging.debug("config reload requested: %s (force=%s)", reason, force) - - def tick(self) -> None: - self._detect_file_change() - self._apply_if_pending() - - def _run(self) -> None: - while not self._stop_event.wait(self.poll_interval_sec): - self.tick() - - def _detect_file_change(self) -> None: - fingerprint = self._fingerprint() - if fingerprint == self._last_seen_fingerprint: - return - self._last_seen_fingerprint = fingerprint - self.request_reload("file change detected", force=False) - - def _apply_if_pending(self) -> None: - with self._request_lock: - reason = self._pending_reload_reason - force_reload = self._pending_force - if not reason: - return - - if self.daemon.get_state() != State.IDLE: - logging.debug("config reload deferred; daemon state=%s", self.daemon.get_state()) - return - - try: - candidate = load(str(self.config_path)) - except Exception as exc: - self._clear_pending() - msg = f"config reload failed ({reason}): {exc}" - logging.error(msg) - self._notify("Config Reload Failed", str(exc), error=True) - return - - status, changed, error = self.daemon.try_apply_config(candidate) - if status == "deferred": - logging.debug("config reload deferred during apply: %s", error) - return - self._clear_pending() - if status == "error": - msg = f"config reload failed ({reason}): {error}" - logging.error(msg) - self._notify("Config Reload Failed", error, error=True) - return - - if changed: - msg = ", ".join(changed) - logging.info("config reloaded (%s): %s", reason, msg) - self._notify("Config Reloaded", f"Applied: {msg}", error=False) - return - if force_reload: - logging.info("config reloaded (%s): no effective changes", reason) - self._notify("Config Reloaded", "No effective changes", error=False) - - def _clear_pending(self) -> None: - with self._request_lock: - self._pending_reload_reason = "" - self._pending_force = False - - def _fingerprint(self) -> str | None: - try: - data = self.config_path.read_bytes() - return hashlib.sha256(data).hexdigest() - except FileNotFoundError: - return None - except OSError as exc: - logging.warning("config fingerprint failed (%s): %s", self.config_path, exc) - return None - - def _notify(self, title: str, body: str, *, error: bool) -> None: - if self.notifier is None: - return - sent = self.notifier.send(title, body, error=error) - if not sent: - logging.debug("desktop notifications unavailable: %s", title) - - -class PipelineReloader: - def __init__( - self, - *, - daemon: Daemon, - pipelines_path: Path, - notifier: DesktopNotifier | None = None, - dry_run: bool = False, - poll_interval_sec: float = CONFIG_RELOAD_POLL_INTERVAL_SEC, - ): - self.daemon = daemon - self.pipelines_path = pipelines_path - self.notifier = notifier - self.dry_run = dry_run - self.poll_interval_sec = poll_interval_sec - self._stop_event = threading.Event() - self._request_lock = threading.Lock() - self._pending_reload_reason = "" - self._pending_force = False - self._last_seen_fingerprint = pipelines_fingerprint(self.pipelines_path) - self._thread: threading.Thread | None = None - - def start(self) -> None: - if self._thread is not None: - return - self._thread = threading.Thread(target=self._run, daemon=True, name="pipeline-reloader") - self._thread.start() - - def stop(self, timeout: float = 2.0) -> None: - self._stop_event.set() - if self._thread is not None: - self._thread.join(timeout=timeout) - self._thread = None - - def request_reload(self, reason: str, *, force: bool = False) -> None: - with self._request_lock: - if self._pending_reload_reason: - self._pending_reload_reason = f"{self._pending_reload_reason}; {reason}" - else: - self._pending_reload_reason = reason - self._pending_force = self._pending_force or force - logging.debug("pipeline reload requested: %s (force=%s)", reason, force) - - def tick(self) -> None: - self._detect_file_change() - self._apply_if_pending() - - def _run(self) -> None: - while not self._stop_event.wait(self.poll_interval_sec): - self.tick() - - def _detect_file_change(self) -> None: - fingerprint = pipelines_fingerprint(self.pipelines_path) - if fingerprint == self._last_seen_fingerprint: - return - self._last_seen_fingerprint = fingerprint - self.request_reload("file change detected", force=False) - - def _apply_if_pending(self) -> None: - with self._request_lock: - reason = self._pending_reload_reason - force_reload = self._pending_force - if not reason: - return - if self.daemon.get_state() != State.IDLE: - logging.debug("pipeline reload deferred; daemon state=%s", self.daemon.get_state()) - return - - try: - bindings = load_bindings( - path=self.pipelines_path, - default_hotkey=self.daemon.cfg.daemon.hotkey, - default_handler_factory=self.daemon.build_reference_pipeline, - ) - callbacks = self.daemon.build_hotkey_callbacks( - list(bindings.keys()), - dry_run=self.dry_run, - ) - except Exception as exc: - self._clear_pending() - logging.error("pipeline reload failed (%s): %s", reason, exc) - self._notify("Pipelines Reload Failed", str(exc), error=True) - return - - status, error = self.daemon.apply_pipeline_bindings(bindings, callbacks) - if status == "deferred": - logging.debug("pipeline reload deferred during apply: %s", error) - return - self._clear_pending() - if status == "error": - logging.error("pipeline reload failed (%s): %s", reason, error) - self._notify("Pipelines Reload Failed", error, error=True) - return - - hotkeys = ", ".join(sorted(bindings.keys())) - logging.info("pipelines reloaded (%s): %s", reason, hotkeys) - self._notify("Pipelines Reloaded", f"Hotkeys: {hotkeys}", error=False) - if force_reload and not bindings: - logging.info("pipelines reloaded (%s): no hotkeys", reason) - - def _clear_pending(self) -> None: - with self._request_lock: - self._pending_reload_reason = "" - self._pending_force = False - - def _notify(self, title: str, body: str, *, error: bool) -> None: - if self.notifier is None: - return - sent = self.notifier.send(title, body, error=error) - if not sent: - logging.debug("desktop notifications unavailable: %s", title) + self._stt_hint_kwargs_cache = kwargs + return self._stt_hint_kwargs_cache def _read_lock_pid(lock_file) -> str: @@ -869,49 +366,19 @@ def main(): level=logging.DEBUG if args.verbose else logging.INFO, format="aman: %(asctime)s %(levelname)s %(message)s", ) - config_path = Path(args.config) if args.config else DEFAULT_CONFIG_PATH - cfg = load(str(config_path)) + cfg = load(args.config) _LOCK_HANDLE = _lock_single_instance() logging.info("hotkey: %s", cfg.daemon.hotkey) logging.info( "config (%s):\n%s", - str(config_path), + args.config or str(Path.home() / ".config" / "aman" / "config.json"), json.dumps(redacted_dict(cfg), indent=2), ) - config_reloader = None - pipeline_reloader = None try: desktop = get_desktop_adapter() daemon = Daemon(cfg, desktop, verbose=args.verbose) - notifier = DesktopNotifier(app_name="aman") - config_reloader = ConfigReloader( - daemon=daemon, - config_path=config_path, - notifier=notifier, - poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC, - ) - pipelines_path = DEFAULT_PIPELINES_PATH - initial_bindings = load_bindings( - path=pipelines_path, - default_hotkey=cfg.daemon.hotkey, - default_handler_factory=daemon.build_reference_pipeline, - ) - initial_callbacks = daemon.build_hotkey_callbacks( - list(initial_bindings.keys()), - dry_run=args.dry_run, - ) - status, error = daemon.apply_pipeline_bindings(initial_bindings, initial_callbacks) - if status != "applied": - raise RuntimeError(f"pipeline setup failed: {error}") - pipeline_reloader = PipelineReloader( - daemon=daemon, - pipelines_path=pipelines_path, - notifier=notifier, - dry_run=args.dry_run, - poll_interval_sec=CONFIG_RELOAD_POLL_INTERVAL_SEC, - ) except Exception as exc: logging.error("startup failed: %s", exc) raise SystemExit(1) @@ -923,10 +390,6 @@ def main(): return shutdown_once.set() logging.info("%s, shutting down", reason) - if config_reloader is not None: - config_reloader.stop(timeout=2.0) - if pipeline_reloader is not None: - pipeline_reloader.stop(timeout=2.0) if not daemon.shutdown(timeout=5.0): logging.warning("timed out waiting for idle state during shutdown") desktop.request_quit() @@ -934,30 +397,15 @@ def main(): def handle_signal(_sig, _frame): threading.Thread(target=shutdown, args=("signal received",), daemon=True).start() - def handle_reload_signal(_sig, _frame): - if shutdown_once.is_set(): - return - if config_reloader is not None: - threading.Thread( - target=lambda: config_reloader.request_reload("signal SIGHUP", force=True), - daemon=True, - ).start() - if pipeline_reloader is not None: - threading.Thread( - target=lambda: pipeline_reloader.request_reload("signal SIGHUP", force=True), - daemon=True, - ).start() - signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) - signal.signal(signal.SIGHUP, handle_reload_signal) try: + desktop.start_hotkey_listener( + cfg.daemon.hotkey, + lambda: logging.info("hotkey pressed (dry-run)") if args.dry_run else daemon.toggle(), + ) desktop.start_cancel_listener(lambda: daemon.cancel_recording()) - if config_reloader is not None: - config_reloader.start() - if pipeline_reloader is not None: - pipeline_reloader.start() except Exception as exc: logging.error("hotkey setup failed: %s", exc) raise SystemExit(1) @@ -965,10 +413,6 @@ def main(): try: desktop.run_tray(daemon.get_state, lambda: shutdown("quit requested")) finally: - if config_reloader is not None: - config_reloader.stop(timeout=2.0) - if pipeline_reloader is not None: - pipeline_reloader.stop(timeout=2.0) daemon.shutdown(timeout=1.0) diff --git a/src/config.py b/src/config.py index 2557c4f..6236966 100644 --- a/src/config.py +++ b/src/config.py @@ -51,6 +51,11 @@ class VocabularyConfig: terms: list[str] = field(default_factory=list) +@dataclass +class DomainInferenceConfig: + enabled: bool = True + + @dataclass class Config: daemon: DaemonConfig = field(default_factory=DaemonConfig) @@ -58,6 +63,7 @@ class Config: stt: SttConfig = field(default_factory=SttConfig) injection: InjectionConfig = field(default_factory=InjectionConfig) vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig) + domain_inference: DomainInferenceConfig = field(default_factory=DomainInferenceConfig) def load(path: str | None) -> Config: @@ -118,8 +124,20 @@ def validate(cfg: Config) -> None: cfg.vocabulary.replacements = _validate_replacements(cfg.vocabulary.replacements) cfg.vocabulary.terms = _validate_terms(cfg.vocabulary.terms) + if not isinstance(cfg.domain_inference.enabled, bool): + raise ValueError("domain_inference.enabled must be boolean") + def _from_dict(data: dict[str, Any], cfg: Config) -> Config: + if "logging" in data: + raise ValueError("logging section is no longer supported; use -v/--verbose") + if "log_transcript" in data: + raise ValueError("log_transcript is no longer supported; use -v/--verbose") + if "ai" in data: + raise ValueError("ai section is no longer supported") + if "ai_enabled" in data: + raise ValueError("ai_enabled is no longer supported") + has_sections = any( key in data for key in ( @@ -128,6 +146,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: "stt", "injection", "vocabulary", + "domain_inference", ) ) if has_sections: @@ -136,6 +155,7 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: stt = _ensure_dict(data.get("stt"), "stt") injection = _ensure_dict(data.get("injection"), "injection") vocabulary = _ensure_dict(data.get("vocabulary"), "vocabulary") + domain_inference = _ensure_dict(data.get("domain_inference"), "domain_inference") if "hotkey" in daemon: cfg.daemon.hotkey = _as_nonempty_str(daemon["hotkey"], "daemon.hotkey") @@ -156,7 +176,28 @@ def _from_dict(data: dict[str, Any], cfg: Config) -> Config: cfg.vocabulary.replacements = _as_replacements(vocabulary["replacements"]) if "terms" in vocabulary: cfg.vocabulary.terms = _as_terms(vocabulary["terms"]) + if "max_rules" in vocabulary: + raise ValueError("vocabulary.max_rules is no longer supported") + if "max_terms" in vocabulary: + raise ValueError("vocabulary.max_terms is no longer supported") + if "enabled" in domain_inference: + cfg.domain_inference.enabled = _as_bool( + domain_inference["enabled"], "domain_inference.enabled" + ) + if "mode" in domain_inference: + raise ValueError("domain_inference.mode is no longer supported") return cfg + + if "hotkey" in data: + cfg.daemon.hotkey = _as_nonempty_str(data["hotkey"], "hotkey") + if "input" in data: + cfg.recording.input = _as_recording_input(data["input"]) + if "whisper_model" in data: + cfg.stt.model = _as_nonempty_str(data["whisper_model"], "whisper_model") + if "whisper_device" in data: + cfg.stt.device = _as_nonempty_str(data["whisper_device"], "whisper_device") + if "injection_backend" in data: + cfg.injection.backend = _as_nonempty_str(data["injection_backend"], "injection_backend") return cfg diff --git a/src/constants.py b/src/constants.py index 7dd942f..df06122 100644 --- a/src/constants.py +++ b/src/constants.py @@ -2,7 +2,6 @@ from pathlib import Path DEFAULT_CONFIG_PATH = Path.home() / ".config" / "aman" / "config.json" -CONFIG_RELOAD_POLL_INTERVAL_SEC = 1.0 RECORD_TIMEOUT_SEC = 300 STT_LANGUAGE = "en" TRAY_UPDATE_MS = 250 diff --git a/src/desktop.py b/src/desktop.py index 1ad650d..23ac5f0 100644 --- a/src/desktop.py +++ b/src/desktop.py @@ -5,7 +5,7 @@ from typing import Callable, Protocol class DesktopAdapter(Protocol): - def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: + def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None: raise NotImplementedError def start_cancel_listener(self, callback: Callable[[], None]) -> None: diff --git a/src/desktop_wayland.py b/src/desktop_wayland.py index 871a695..1da88a8 100644 --- a/src/desktop_wayland.py +++ b/src/desktop_wayland.py @@ -4,7 +4,7 @@ from typing import Callable class WaylandAdapter: - def set_hotkeys(self, _bindings: dict[str, Callable[[], None]]) -> None: + def start_hotkey_listener(self, _hotkey: str, _callback: Callable[[], None]) -> None: raise SystemExit("Wayland hotkeys are not supported yet.") def start_cancel_listener(self, _callback: Callable[[], None]) -> None: diff --git a/src/desktop_x11.py b/src/desktop_x11.py index 323371b..d79ca0d 100644 --- a/src/desktop_x11.py +++ b/src/desktop_x11.py @@ -4,7 +4,7 @@ import logging import threading import time import warnings -from typing import Any, Callable, Iterable +from typing import Callable, Iterable import gi from Xlib import X, XK, display @@ -43,8 +43,6 @@ class X11Adapter: self.indicator = None self.status_icon = None self.menu = None - self._hotkey_listener_lock = threading.Lock() - self._hotkey_listeners: dict[str, dict[str, Any]] = {} if AppIndicator3 is not None: self.indicator = AppIndicator3.Indicator.new( "aman", @@ -67,29 +65,15 @@ class X11Adapter: if self.menu: self.menu.popup(None, None, None, None, 0, _time) - def set_hotkeys(self, bindings: dict[str, Callable[[], None]]) -> None: - if not isinstance(bindings, dict): - raise ValueError("bindings must be a dictionary") - next_listeners: dict[str, dict[str, Any]] = {} - try: - for hotkey, callback in bindings.items(): - if not callable(callback): - raise ValueError(f"callback for hotkey {hotkey} must be callable") - next_listeners[hotkey] = self._spawn_hotkey_listener(hotkey, callback) - except Exception: - for listener in next_listeners.values(): - self._stop_hotkey_listener(listener) - raise - - with self._hotkey_listener_lock: - previous = self._hotkey_listeners - self._hotkey_listeners = next_listeners - for listener in previous.values(): - self._stop_hotkey_listener(listener) + def start_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> None: + mods, keysym = self._parse_hotkey(hotkey) + self._validate_hotkey_registration(mods, keysym) + thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True) + thread.start() def start_cancel_listener(self, callback: Callable[[], None]) -> None: mods, keysym = self._parse_hotkey("Escape") - thread = threading.Thread(target=self._listen, args=(mods, keysym, callback, threading.Event()), daemon=True) + thread = threading.Thread(target=self._listen, args=(mods, keysym, callback), daemon=True) thread.start() def inject_text( @@ -143,14 +127,7 @@ class X11Adapter: finally: self.request_quit() - def _listen( - self, - mods: int, - keysym: int, - callback: Callable[[], None], - stop_event: threading.Event, - listener_meta: dict[str, Any] | None = None, - ) -> None: + def _listen(self, mods: int, keysym: int, callback: Callable[[], None]) -> None: disp = None root = None keycode = None @@ -158,26 +135,14 @@ class X11Adapter: disp = display.Display() root = disp.screen().root keycode = self._grab_hotkey(disp, root, mods, keysym) - if listener_meta is not None: - listener_meta["display"] = disp - listener_meta["root"] = root - listener_meta["keycode"] = keycode - listener_meta["ready"].set() - while not stop_event.is_set(): - if disp.pending_events() == 0: - time.sleep(0.05) - continue + while True: ev = disp.next_event() if ev.type == X.KeyPress and ev.detail == keycode: state = ev.state & ~(X.LockMask | X.Mod2Mask) if state == mods: callback() except Exception as exc: - if listener_meta is not None: - listener_meta["error"] = exc - listener_meta["ready"].set() - if not stop_event.is_set(): - logging.error("hotkey listener stopped: %s", exc) + logging.error("hotkey listener stopped: %s", exc) finally: if root is not None and keycode is not None and disp is not None: try: @@ -185,13 +150,6 @@ class X11Adapter: disp.sync() except Exception: pass - if disp is not None: - try: - disp.close() - except Exception: - pass - if listener_meta is not None: - listener_meta["ready"].set() def _parse_hotkey(self, hotkey: str): mods = 0 @@ -227,51 +185,6 @@ class X11Adapter: except Exception: pass - def _spawn_hotkey_listener(self, hotkey: str, callback: Callable[[], None]) -> dict[str, Any]: - mods, keysym = self._parse_hotkey(hotkey) - self._validate_hotkey_registration(mods, keysym) - stop_event = threading.Event() - listener_meta: dict[str, Any] = { - "hotkey": hotkey, - "mods": mods, - "keysym": keysym, - "stop_event": stop_event, - "display": None, - "root": None, - "keycode": None, - "error": None, - "ready": threading.Event(), - } - thread = threading.Thread( - target=self._listen, - args=(mods, keysym, callback, stop_event, listener_meta), - daemon=True, - ) - listener_meta["thread"] = thread - thread.start() - if not listener_meta["ready"].wait(timeout=2.0): - stop_event.set() - raise RuntimeError("hotkey listener setup timed out") - if listener_meta["error"] is not None: - stop_event.set() - raise listener_meta["error"] - if listener_meta["keycode"] is None: - stop_event.set() - raise RuntimeError("hotkey listener failed to initialize") - return listener_meta - - def _stop_hotkey_listener(self, listener: dict[str, Any]) -> None: - listener["stop_event"].set() - disp = listener.get("display") - if disp is not None: - try: - disp.close() - except Exception: - pass - thread = listener.get("thread") - if thread is not None: - thread.join(timeout=1.0) - def _grab_hotkey(self, disp, root, mods, keysym): keycode = disp.keysym_to_keycode(keysym) if keycode == 0: diff --git a/src/engine.py b/src/engine.py deleted file mode 100644 index de54c90..0000000 --- a/src/engine.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Any, Callable - - -ALLOWED_FAILURE_POLICIES = {"best_effort", "strict"} - - -@dataclass(frozen=True) -class PipelineOptions: - failure_policy: str = "best_effort" - - def __post_init__(self): - policy = (self.failure_policy or "").strip().lower() - if policy not in ALLOWED_FAILURE_POLICIES: - allowed = ", ".join(sorted(ALLOWED_FAILURE_POLICIES)) - raise ValueError(f"failure_policy must be one of: {allowed}") - object.__setattr__(self, "failure_policy", policy) - - -@dataclass(frozen=True) -class PipelineBinding: - hotkey: str - handler: Callable[[Any, "PipelineLib"], str] - options: PipelineOptions - - -class PipelineLib: - def __init__( - self, - *, - transcribe_fn: Callable[[Any], str], - llm_fn: Callable[..., str], - ): - self._transcribe_fn = transcribe_fn - self._llm_fn = llm_fn - - def transcribe( - self, - audio: Any, - *, - hints: list[str] | None = None, - whisper_opts: dict[str, Any] | None = None, - ) -> str: - return self._transcribe_fn( - audio, - hints=hints, - whisper_opts=whisper_opts, - ) - - def llm( - self, - *, - system_prompt: str, - user_prompt: str, - llm_opts: dict[str, Any] | None = None, - ) -> str: - return self._llm_fn( - system_prompt=system_prompt, - user_prompt=user_prompt, - llm_opts=llm_opts, - ) - - -class Engine: - def __init__(self, lib: PipelineLib): - self.lib = lib - - def run(self, binding: PipelineBinding, audio: Any) -> str: - try: - output = binding.handler(audio, self.lib) - except Exception as exc: - if binding.options.failure_policy == "strict": - raise - logging.error("pipeline failed for hotkey %s: %s", binding.hotkey, exc) - return "" - - if output is None: - return "" - if not isinstance(output, str): - raise RuntimeError( - f"pipeline for hotkey {binding.hotkey} returned non-string output" - ) - return output diff --git a/src/notify.py b/src/notify.py deleted file mode 100644 index 27469a2..0000000 --- a/src/notify.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - - -class DesktopNotifier: - def __init__(self, app_name: str = "aman"): - self._app_name = app_name - self._notify = None - self._ready = False - self._attempted = False - - def send(self, title: str, body: str, *, error: bool = False) -> bool: - _ = error - notify = self._ensure_backend() - if notify is None: - return False - try: - notification = notify.Notification.new(title, body, None) - notification.show() - return True - except Exception: - return False - - def _ensure_backend(self): - if self._attempted: - return self._notify if self._ready else None - - self._attempted = True - try: - import gi - - gi.require_version("Notify", "0.7") - from gi.repository import Notify # type: ignore[import-not-found] - - if Notify.init(self._app_name): - self._notify = Notify - self._ready = True - except Exception: - self._notify = None - self._ready = False - return self._notify if self._ready else None diff --git a/src/pipelines_runtime.py b/src/pipelines_runtime.py deleted file mode 100644 index 8aa7985..0000000 --- a/src/pipelines_runtime.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -import hashlib -import importlib.util -from pathlib import Path -from types import ModuleType -from typing import Any, Callable - -from engine import PipelineBinding, PipelineOptions - - -DEFAULT_PIPELINES_PATH = Path.home() / ".config" / "aman" / "pipelines.py" - - -def fingerprint(path: Path) -> str | None: - try: - data = path.read_bytes() - except FileNotFoundError: - return None - return hashlib.sha256(data).hexdigest() - - -def load_bindings( - *, - path: Path, - default_hotkey: str, - default_handler_factory: Callable[[], Callable[[Any, Any], str]], -) -> dict[str, PipelineBinding]: - if not path.exists(): - handler = default_handler_factory() - options = PipelineOptions(failure_policy="best_effort") - return { - default_hotkey: PipelineBinding( - hotkey=default_hotkey, - handler=handler, - options=options, - ) - } - - module = _load_module(path) - return _bindings_from_module(module) - - -def _load_module(path: Path) -> ModuleType: - module_name = f"aman_user_pipelines_{hash(path.resolve())}" - spec = importlib.util.spec_from_file_location(module_name, str(path)) - if spec is None or spec.loader is None: - raise ValueError(f"unable to load pipelines module: {path}") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _bindings_from_module(module: ModuleType) -> dict[str, PipelineBinding]: - raw_pipelines = getattr(module, "HOTKEY_PIPELINES", None) - if not isinstance(raw_pipelines, dict): - raise ValueError("HOTKEY_PIPELINES must be a dictionary") - if not raw_pipelines: - raise ValueError("HOTKEY_PIPELINES cannot be empty") - - raw_options = getattr(module, "PIPELINE_OPTIONS", {}) - if raw_options is None: - raw_options = {} - if not isinstance(raw_options, dict): - raise ValueError("PIPELINE_OPTIONS must be a dictionary when provided") - - bindings: dict[str, PipelineBinding] = {} - for key, handler in raw_pipelines.items(): - hotkey = _validate_hotkey(key) - if not callable(handler): - raise ValueError(f"HOTKEY_PIPELINES[{hotkey}] must be callable") - option_data = raw_options.get(hotkey, {}) - options = _parse_options(hotkey, option_data) - bindings[hotkey] = PipelineBinding( - hotkey=hotkey, - handler=handler, - options=options, - ) - return bindings - - -def _validate_hotkey(value: Any) -> str: - if not isinstance(value, str): - raise ValueError("pipeline hotkey keys must be strings") - hotkey = value.strip() - if not hotkey: - raise ValueError("pipeline hotkey keys cannot be empty") - return hotkey - - -def _parse_options(hotkey: str, value: Any) -> PipelineOptions: - if value is None: - return PipelineOptions() - if not isinstance(value, dict): - raise ValueError(f"PIPELINE_OPTIONS[{hotkey}] must be an object") - failure_policy = value.get("failure_policy", "best_effort") - return PipelineOptions(failure_policy=failure_policy) diff --git a/src/vocabulary.py b/src/vocabulary.py index 2629eb6..1a7ebb2 100644 --- a/src/vocabulary.py +++ b/src/vocabulary.py @@ -4,7 +4,101 @@ import re from dataclasses import dataclass from typing import Iterable -from config import VocabularyConfig +from config import DomainInferenceConfig, VocabularyConfig + + +DOMAIN_GENERAL = "general" +DOMAIN_PERSONAL_NAMES = "personal_names" +DOMAIN_SOFTWARE_DEV = "software_dev" +DOMAIN_OPS_INFRA = "ops_infra" +DOMAIN_BUSINESS = "business" +DOMAIN_MEDICAL_LEGAL = "medical_legal" + +DOMAIN_ORDER = ( + DOMAIN_PERSONAL_NAMES, + DOMAIN_SOFTWARE_DEV, + DOMAIN_OPS_INFRA, + DOMAIN_BUSINESS, + DOMAIN_MEDICAL_LEGAL, +) + +DOMAIN_KEYWORDS = { + DOMAIN_SOFTWARE_DEV: { + "api", + "bug", + "code", + "commit", + "docker", + "function", + "git", + "github", + "javascript", + "python", + "refactor", + "repository", + "typescript", + "unit", + "test", + }, + DOMAIN_OPS_INFRA: { + "cluster", + "container", + "deploy", + "deployment", + "incident", + "kubernetes", + "monitoring", + "nginx", + "pod", + "prod", + "service", + "systemd", + "terraform", + }, + DOMAIN_BUSINESS: { + "budget", + "client", + "deadline", + "finance", + "invoice", + "meeting", + "milestone", + "project", + "quarter", + "roadmap", + "sales", + "stakeholder", + }, + DOMAIN_MEDICAL_LEGAL: { + "agreement", + "case", + "claim", + "compliance", + "contract", + "diagnosis", + "liability", + "patient", + "prescription", + "regulation", + "symptom", + "treatment", + }, +} + +DOMAIN_PHRASES = { + DOMAIN_SOFTWARE_DEV: ("pull request", "code review", "integration test"), + DOMAIN_OPS_INFRA: ("on call", "service restart", "roll back"), + DOMAIN_BUSINESS: ("follow up", "action items", "meeting notes"), + DOMAIN_MEDICAL_LEGAL: ("terms and conditions", "medical record", "legal review"), +} + +GREETING_TOKENS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"} + + +@dataclass(frozen=True) +class DomainResult: + name: str + confidence: float @dataclass(frozen=True) @@ -14,9 +108,10 @@ class _ReplacementView: class VocabularyEngine: - def __init__(self, vocab_cfg: VocabularyConfig): + def __init__(self, vocab_cfg: VocabularyConfig, domain_cfg: DomainInferenceConfig): self._replacements = [_ReplacementView(r.source, r.target) for r in vocab_cfg.replacements] self._terms = list(vocab_cfg.terms) + self._domain_enabled = bool(domain_cfg.enabled) self._replacement_map = { _normalize_key(rule.source): rule.target for rule in self._replacements @@ -66,6 +161,55 @@ class VocabularyEngine: used += addition return "\n".join(out) + def infer_domain(self, text: str) -> DomainResult: + if not self._domain_enabled: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + normalized = text.casefold() + tokens = re.findall(r"[a-z0-9+#./_-]+", normalized) + if not tokens: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + scores = {domain: 0 for domain in DOMAIN_ORDER} + for token in tokens: + for domain, keywords in DOMAIN_KEYWORDS.items(): + if token in keywords: + scores[domain] += 2 + + for domain, phrases in DOMAIN_PHRASES.items(): + for phrase in phrases: + if phrase in normalized: + scores[domain] += 2 + + if any(token in GREETING_TOKENS for token in tokens): + scores[DOMAIN_PERSONAL_NAMES] += 1 + + # Boost domains from configured dictionary terms and replacement targets. + dictionary_tokens = self._dictionary_tokens() + for token in dictionary_tokens: + for domain, keywords in DOMAIN_KEYWORDS.items(): + if token in keywords and token in tokens: + scores[domain] += 1 + + top_domain = DOMAIN_GENERAL + top_score = 0 + total_score = 0 + for domain in DOMAIN_ORDER: + score = scores[domain] + total_score += score + if score > top_score: + top_score = score + top_domain = domain + + if top_score < 2 or total_score == 0: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + confidence = top_score / total_score + if confidence < 0.45: + return DomainResult(name=DOMAIN_GENERAL, confidence=0.0) + + return DomainResult(name=top_domain, confidence=round(confidence, 2)) + def _build_stt_hotwords(self, *, limit: int, char_budget: int) -> str: items = _dedupe_preserve_order( [rule.target for rule in self._replacements] + self._terms @@ -92,6 +236,19 @@ class VocabularyEngine: return "" return prefix + hotwords + def _dictionary_tokens(self) -> set[str]: + values: list[str] = [] + for rule in self._replacements: + values.append(rule.source) + values.append(rule.target) + values.extend(self._terms) + + tokens: set[str] = set() + for value in values: + for token in re.findall(r"[a-z0-9+#./_-]+", value.casefold()): + tokens.add(token) + return tokens + def _build_replacement_pattern(sources: Iterable[str]) -> re.Pattern[str] | None: unique_sources = _dedupe_preserve_order(list(sources)) diff --git a/tests/test_aman.py b/tests/test_aman.py index d0147ed..d9a9867 100644 --- a/tests/test_aman.py +++ b/tests/test_aman.py @@ -1,4 +1,3 @@ -import json import os import sys import tempfile @@ -12,25 +11,14 @@ if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) import aman -from config import Config, VocabularyReplacement, redacted_dict -from engine import PipelineBinding, PipelineOptions +from config import Config, VocabularyReplacement class FakeDesktop: def __init__(self): self.inject_calls = [] - self.hotkey_updates = [] - self.hotkeys = {} - self.cancel_callback = None self.quit_calls = 0 - def set_hotkeys(self, bindings): - self.hotkeys = dict(bindings) - self.hotkey_updates.append(tuple(sorted(bindings.keys()))) - - def start_cancel_listener(self, callback): - self.cancel_callback = callback - def inject_text( self, text: str, @@ -88,30 +76,12 @@ class FakeAIProcessor: def process(self, text, lang="en", **_kwargs): return text - def chat(self, *, system_prompt, user_prompt, llm_opts=None): - _ = system_prompt - opts = llm_opts or {} - if "response_format" in opts: - payload = json.loads(user_prompt) - transcript = payload.get("transcript", "") - return json.dumps({"cleaned_text": transcript}) - return "general" - class FakeAudio: def __init__(self, size: int): self.size = size -class FakeNotifier: - def __init__(self): - self.events = [] - - def send(self, title, body, *, error=False): - self.events.append((title, body, error)) - return True - - class DaemonTests(unittest.TestCase): def _config(self) -> Config: cfg = Config() @@ -133,23 +103,19 @@ class DaemonTests(unittest.TestCase): @patch("aman.stop_audio_recording", return_value=FakeAudio(8)) @patch("aman.start_audio_recording", return_value=(object(), object())) - def test_hotkey_start_stop_injects_text(self, _start_mock, _stop_mock): + def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock): desktop = FakeDesktop() daemon = self._build_daemon(desktop, FakeModel(), verbose=False) daemon._start_stop_worker = ( lambda stream, record, trigger, process_audio: daemon._stop_and_process( - stream, - record, - trigger, - process_audio, - daemon._pending_worker_hotkey, + stream, record, trigger, process_audio ) ) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.toggle() self.assertEqual(daemon.get_state(), aman.State.RECORDING) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.toggle() self.assertEqual(daemon.get_state(), aman.State.IDLE) self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)]) @@ -165,7 +131,7 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.toggle() self.assertEqual(daemon.get_state(), aman.State.RECORDING) self.assertTrue(daemon.shutdown(timeout=0.2)) @@ -187,8 +153,8 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.toggle() + daemon.toggle() self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)]) @@ -257,8 +223,8 @@ class DaemonTests(unittest.TestCase): ) ) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) - daemon.handle_hotkey(daemon.cfg.daemon.hotkey) + daemon.toggle() + daemon.toggle() self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)]) @@ -273,161 +239,6 @@ class DaemonTests(unittest.TestCase): any("DEBUG:root:state: idle -> recording" in line for line in logs.output) ) - def test_hotkey_dispatches_to_matching_pipeline(self): - desktop = FakeDesktop() - daemon = self._build_daemon(desktop, FakeModel(), verbose=False) - daemon.set_pipeline_bindings( - { - "Super+m": PipelineBinding( - hotkey="Super+m", - handler=lambda _audio, _lib: "default", - options=PipelineOptions(failure_policy="best_effort"), - ), - "Super+Shift+m": PipelineBinding( - hotkey="Super+Shift+m", - handler=lambda _audio, _lib: "caps", - options=PipelineOptions(failure_policy="best_effort"), - ), - } - ) - - out = daemon._run_pipeline(object(), "Super+Shift+m") - self.assertEqual(out, "caps") - - def test_try_apply_config_applies_new_runtime_settings(self): - desktop = FakeDesktop() - cfg = self._config() - daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False) - daemon._stt_hint_kwargs_cache = {"hotwords": "old"} - - candidate = Config() - candidate.injection.backend = "injection" - candidate.vocabulary.replacements = [ - VocabularyReplacement(source="Martha", target="Marta"), - ] - - status, changed, error = daemon.try_apply_config(candidate) - - self.assertEqual(status, "applied") - self.assertEqual(error, "") - self.assertIn("injection.backend", changed) - self.assertIn("vocabulary.replacements", changed) - self.assertEqual(daemon.cfg.injection.backend, "injection") - self.assertEqual( - daemon.vocabulary.apply_deterministic_replacements("Martha"), - "Marta", - ) - self.assertIsNone(daemon._stt_hint_kwargs_cache) - - def test_try_apply_config_reloads_stt_model(self): - desktop = FakeDesktop() - cfg = self._config() - daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False) - - candidate = Config() - candidate.stt.model = "small" - next_model = FakeModel(text="from-new-model") - with patch("aman._build_whisper_model", return_value=next_model): - status, changed, error = daemon.try_apply_config(candidate) - - self.assertEqual(status, "applied") - self.assertEqual(error, "") - self.assertIn("stt.model", changed) - self.assertIs(daemon.model, next_model) - - def test_try_apply_config_is_deferred_while_busy(self): - desktop = FakeDesktop() - daemon = self._build_daemon(desktop, FakeModel(), verbose=False) - daemon.set_state(aman.State.RECORDING) - - status, changed, error = daemon.try_apply_config(Config()) - - self.assertEqual(status, "deferred") - self.assertEqual(changed, []) - self.assertIn("busy", error) - - -class ConfigReloaderTests(unittest.TestCase): - def _write_config(self, path: Path, cfg: Config): - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(redacted_dict(cfg)), encoding="utf-8") - - def _daemon_for_path(self, path: Path) -> tuple[aman.Daemon, FakeDesktop]: - cfg = Config() - self._write_config(path, cfg) - desktop = FakeDesktop() - with patch("aman._build_whisper_model", return_value=FakeModel()), patch( - "aman.LlamaProcessor", return_value=FakeAIProcessor() - ): - daemon = aman.Daemon(cfg, desktop, verbose=False) - return daemon, desktop - - def test_reloader_applies_changed_config(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - daemon, _desktop = self._daemon_for_path(path) - notifier = FakeNotifier() - reloader = aman.ConfigReloader( - daemon=daemon, - config_path=path, - notifier=notifier, - poll_interval_sec=0.01, - ) - - updated = Config() - updated.injection.backend = "injection" - self._write_config(path, updated) - reloader.tick() - - self.assertEqual(daemon.cfg.injection.backend, "injection") - self.assertTrue(any(evt[0] == "Config Reloaded" and not evt[2] for evt in notifier.events)) - - def test_reloader_keeps_last_good_config_when_invalid(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - daemon, _desktop = self._daemon_for_path(path) - notifier = FakeNotifier() - reloader = aman.ConfigReloader( - daemon=daemon, - config_path=path, - notifier=notifier, - poll_interval_sec=0.01, - ) - - path.write_text('{"injection":{"backend":"invalid"}}', encoding="utf-8") - reloader.tick() - self.assertEqual(daemon.cfg.injection.backend, "clipboard") - fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"] - self.assertEqual(len(fail_events), 1) - - reloader.tick() - fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"] - self.assertEqual(len(fail_events), 1) - - def test_reloader_defers_apply_until_idle(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "config.json" - daemon, _desktop = self._daemon_for_path(path) - notifier = FakeNotifier() - reloader = aman.ConfigReloader( - daemon=daemon, - config_path=path, - notifier=notifier, - poll_interval_sec=0.01, - ) - - updated = Config() - updated.injection.backend = "injection" - self._write_config(path, updated) - - daemon.set_state(aman.State.RECORDING) - reloader.tick() - self.assertEqual(daemon.cfg.injection.backend, "clipboard") - - daemon.set_state(aman.State.IDLE) - reloader.tick() - self.assertEqual(daemon.cfg.injection.backend, "injection") - class LockTests(unittest.TestCase): def test_lock_rejects_second_instance(self): diff --git a/tests/test_config.py b/tests/test_config.py index 1e6dd13..f65a9f8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,6 +26,7 @@ class ConfigTests(unittest.TestCase): self.assertFalse(cfg.injection.remove_transcription_from_clipboard) self.assertEqual(cfg.vocabulary.replacements, []) self.assertEqual(cfg.vocabulary.terms, []) + self.assertTrue(cfg.domain_inference.enabled) self.assertTrue(missing.exists()) written = json.loads(missing.read_text(encoding="utf-8")) @@ -47,6 +48,7 @@ class ConfigTests(unittest.TestCase): ], "terms": ["Systemd", "Kubernetes"], }, + "domain_inference": {"enabled": True}, } with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" @@ -64,6 +66,7 @@ class ConfigTests(unittest.TestCase): self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha") self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta") self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"]) + self.assertTrue(cfg.domain_inference.enabled) def test_super_modifier_hotkey_is_valid(self): payload = {"daemon": {"hotkey": "Super+m"}} @@ -95,6 +98,28 @@ class ConfigTests(unittest.TestCase): ): load(str(path)) + def test_loads_legacy_keys(self): + payload = { + "hotkey": "Alt+m", + "input": "Mic", + "whisper_model": "tiny", + "whisper_device": "cpu", + "injection_backend": "clipboard", + } + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + cfg = load(str(path)) + + self.assertEqual(cfg.daemon.hotkey, "Alt+m") + self.assertEqual(cfg.recording.input, "Mic") + self.assertEqual(cfg.stt.model, "tiny") + self.assertEqual(cfg.stt.device, "cpu") + self.assertEqual(cfg.injection.backend, "clipboard") + self.assertFalse(cfg.injection.remove_transcription_from_clipboard) + self.assertEqual(cfg.vocabulary.replacements, []) + def test_invalid_injection_backend_raises(self): payload = {"injection": {"backend": "invalid"}} with tempfile.TemporaryDirectory() as td: @@ -113,20 +138,41 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"): load(str(path)) - def test_unknown_top_level_fields_are_ignored(self): - payload = { - "custom_a": {"enabled": True}, - "custom_b": {"nested": "value"}, - "custom_c": 123, - } + def test_removed_ai_section_raises(self): + payload = {"ai": {"enabled": True}} with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" path.write_text(json.dumps(payload), encoding="utf-8") - cfg = load(str(path)) + with self.assertRaisesRegex(ValueError, "ai section is no longer supported"): + load(str(path)) - self.assertEqual(cfg.daemon.hotkey, "Cmd+m") - self.assertEqual(cfg.injection.backend, "clipboard") + def test_removed_legacy_ai_enabled_raises(self): + payload = {"ai_enabled": True} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "ai_enabled is no longer supported"): + load(str(path)) + + def test_removed_logging_section_raises(self): + payload = {"logging": {"log_transcript": True}} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "no longer supported"): + load(str(path)) + + def test_removed_legacy_log_transcript_raises(self): + payload = {"log_transcript": True} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "no longer supported"): + load(str(path)) def test_conflicting_replacements_raise(self): payload = { @@ -178,15 +224,32 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "wildcard"): load(str(path)) - def test_unknown_vocabulary_fields_are_ignored(self): - payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}} + def test_removed_domain_mode_raises(self): + payload = {"domain_inference": {"mode": "heuristic"}} with tempfile.TemporaryDirectory() as td: path = Path(td) / "config.json" path.write_text(json.dumps(payload), encoding="utf-8") - cfg = load(str(path)) + with self.assertRaisesRegex(ValueError, "domain_inference.mode is no longer supported"): + load(str(path)) - self.assertEqual(cfg.vocabulary.terms, ["Docker"]) + def test_removed_vocabulary_max_rules_raises(self): + payload = {"vocabulary": {"max_rules": 100}} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "vocabulary.max_rules is no longer supported"): + load(str(path)) + + def test_removed_vocabulary_max_terms_raises(self): + payload = {"vocabulary": {"max_terms": 100}} + with tempfile.TemporaryDirectory() as td: + path = Path(td) / "config.json" + path.write_text(json.dumps(payload), encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "vocabulary.max_terms is no longer supported"): + load(str(path)) if __name__ == "__main__": diff --git a/tests/test_engine.py b/tests/test_engine.py deleted file mode 100644 index 4253479..0000000 --- a/tests/test_engine.py +++ /dev/null @@ -1,92 +0,0 @@ -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -SRC = ROOT / "src" -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions - - -class EngineTests(unittest.TestCase): - def test_best_effort_pipeline_failure_returns_empty_string(self): - lib = PipelineLib( - transcribe_fn=lambda *_args, **_kwargs: "text", - llm_fn=lambda **_kwargs: "result", - ) - engine = Engine(lib) - - def failing_pipeline(_audio, _lib): - raise RuntimeError("boom") - - binding = PipelineBinding( - hotkey="Cmd+m", - handler=failing_pipeline, - options=PipelineOptions(failure_policy="best_effort"), - ) - out = engine.run(binding, object()) - self.assertEqual(out, "") - - def test_strict_pipeline_failure_raises(self): - lib = PipelineLib( - transcribe_fn=lambda *_args, **_kwargs: "text", - llm_fn=lambda **_kwargs: "result", - ) - engine = Engine(lib) - - def failing_pipeline(_audio, _lib): - raise RuntimeError("boom") - - binding = PipelineBinding( - hotkey="Cmd+m", - handler=failing_pipeline, - options=PipelineOptions(failure_policy="strict"), - ) - with self.assertRaisesRegex(RuntimeError, "boom"): - engine.run(binding, object()) - - def test_pipeline_lib_forwards_arguments(self): - seen = {} - - def transcribe_fn(audio, *, hints=None, whisper_opts=None): - seen["audio"] = audio - seen["hints"] = hints - seen["whisper_opts"] = whisper_opts - return "hello" - - def llm_fn(*, system_prompt, user_prompt, llm_opts=None): - seen["system_prompt"] = system_prompt - seen["user_prompt"] = user_prompt - seen["llm_opts"] = llm_opts - return "world" - - lib = PipelineLib( - transcribe_fn=transcribe_fn, - llm_fn=llm_fn, - ) - - audio = object() - self.assertEqual( - lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}), - "hello", - ) - self.assertEqual( - lib.llm( - system_prompt="sys", - user_prompt="user", - llm_opts={"temperature": 0.2}, - ), - "world", - ) - self.assertIs(seen["audio"], audio) - self.assertEqual(seen["hints"], ["Docker"]) - self.assertEqual(seen["whisper_opts"], {"vad_filter": True}) - self.assertEqual(seen["system_prompt"], "sys") - self.assertEqual(seen["user_prompt"], "user") - self.assertEqual(seen["llm_opts"], {"temperature": 0.2}) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_pipelines_runtime.py b/tests/test_pipelines_runtime.py deleted file mode 100644 index 4bc644b..0000000 --- a/tests/test_pipelines_runtime.py +++ /dev/null @@ -1,109 +0,0 @@ -import tempfile -import textwrap -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -SRC = ROOT / "src" -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -from pipelines_runtime import fingerprint, load_bindings - - -class PipelinesRuntimeTests(unittest.TestCase): - def test_missing_file_uses_default_binding(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "pipelines.py" - called = {"count": 0} - - def factory(): - called["count"] += 1 - - def handler(_audio, _lib): - return "ok" - - return handler - - bindings = load_bindings( - path=path, - default_hotkey="Cmd+m", - default_handler_factory=factory, - ) - - self.assertEqual(list(bindings.keys()), ["Cmd+m"]) - self.assertEqual(called["count"], 1) - - def test_loads_module_bindings_and_options(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "pipelines.py" - path.write_text( - textwrap.dedent( - """ - def p1(audio, lib): - return "one" - - def p2(audio, lib): - return "two" - - HOTKEY_PIPELINES = { - "Cmd+m": p1, - "Ctrl+space": p2, - } - - PIPELINE_OPTIONS = { - "Ctrl+space": {"failure_policy": "strict"}, - } - """ - ), - encoding="utf-8", - ) - - bindings = load_bindings( - path=path, - default_hotkey="Cmd+m", - default_handler_factory=lambda: (lambda _audio, _lib: "default"), - ) - - self.assertEqual(sorted(bindings.keys()), ["Cmd+m", "Ctrl+space"]) - self.assertEqual(bindings["Cmd+m"].options.failure_policy, "best_effort") - self.assertEqual(bindings["Ctrl+space"].options.failure_policy, "strict") - - def test_invalid_options_fail(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "pipelines.py" - path.write_text( - textwrap.dedent( - """ - def p(audio, lib): - return "ok" - HOTKEY_PIPELINES = {"Cmd+m": p} - PIPELINE_OPTIONS = {"Cmd+m": {"failure_policy": "invalid"}} - """ - ), - encoding="utf-8", - ) - - with self.assertRaisesRegex(ValueError, "failure_policy"): - load_bindings( - path=path, - default_hotkey="Cmd+m", - default_handler_factory=lambda: (lambda _audio, _lib: "default"), - ) - - def test_fingerprint_changes_with_content(self): - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "pipelines.py" - self.assertIsNone(fingerprint(path)) - path.write_text("HOTKEY_PIPELINES = {}", encoding="utf-8") - first = fingerprint(path) - path.write_text("HOTKEY_PIPELINES = {'Cmd+m': lambda a, l: ''}", encoding="utf-8") - second = fingerprint(path) - self.assertIsNotNone(first) - self.assertIsNotNone(second) - self.assertNotEqual(first, second) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_vocabulary.py b/tests/test_vocabulary.py index fd66a17..bc33ea6 100644 --- a/tests/test_vocabulary.py +++ b/tests/test_vocabulary.py @@ -7,17 +7,18 @@ SRC = ROOT / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) -from config import VocabularyConfig, VocabularyReplacement -from vocabulary import VocabularyEngine +from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement +from vocabulary import DOMAIN_GENERAL, VocabularyEngine class VocabularyEngineTests(unittest.TestCase): - def _engine(self, replacements=None, terms=None): + def _engine(self, replacements=None, terms=None, domain_enabled=True): vocab = VocabularyConfig( replacements=replacements or [], terms=terms or [], ) - return VocabularyEngine(vocab) + domain = DomainInferenceConfig(enabled=domain_enabled) + return VocabularyEngine(vocab, domain) def test_boundary_aware_replacement(self): engine = self._engine( @@ -49,5 +50,27 @@ class VocabularyEngineTests(unittest.TestCase): self.assertLessEqual(len(hotwords), 1024) self.assertLessEqual(len(prompt), 600) + def test_domain_inference_general_fallback(self): + engine = self._engine() + result = engine.infer_domain("please call me later") + + self.assertEqual(result.name, DOMAIN_GENERAL) + self.assertEqual(result.confidence, 0.0) + + def test_domain_inference_for_technical_text(self): + engine = self._engine(terms=["Docker", "Systemd"]) + result = engine.infer_domain("restart Docker and systemd service on prod") + + self.assertNotEqual(result.name, DOMAIN_GENERAL) + self.assertGreater(result.confidence, 0.0) + + def test_domain_inference_can_be_disabled(self): + engine = self._engine(domain_enabled=False) + result = engine.infer_domain("please restart docker") + + self.assertEqual(result.name, DOMAIN_GENERAL) + self.assertEqual(result.confidence, 0.0) + + if __name__ == "__main__": unittest.main()