Add pipeline engine and remove legacy compatibility paths
This commit is contained in:
parent
3bc473262d
commit
e221d49020
18 changed files with 1523 additions and 399 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
|
@ -11,14 +12,25 @@ if str(SRC) not in sys.path:
|
|||
sys.path.insert(0, str(SRC))
|
||||
|
||||
import aman
|
||||
from config import Config, VocabularyReplacement
|
||||
from config import Config, VocabularyReplacement, redacted_dict
|
||||
from engine import PipelineBinding, PipelineOptions
|
||||
|
||||
|
||||
class FakeDesktop:
|
||||
def __init__(self):
|
||||
self.inject_calls = []
|
||||
self.hotkey_updates = []
|
||||
self.hotkeys = {}
|
||||
self.cancel_callback = None
|
||||
self.quit_calls = 0
|
||||
|
||||
def set_hotkeys(self, bindings):
|
||||
self.hotkeys = dict(bindings)
|
||||
self.hotkey_updates.append(tuple(sorted(bindings.keys())))
|
||||
|
||||
def start_cancel_listener(self, callback):
|
||||
self.cancel_callback = callback
|
||||
|
||||
def inject_text(
|
||||
self,
|
||||
text: str,
|
||||
|
|
@ -76,12 +88,30 @@ class FakeAIProcessor:
|
|||
def process(self, text, lang="en", **_kwargs):
|
||||
return text
|
||||
|
||||
def chat(self, *, system_prompt, user_prompt, llm_opts=None):
|
||||
_ = system_prompt
|
||||
opts = llm_opts or {}
|
||||
if "response_format" in opts:
|
||||
payload = json.loads(user_prompt)
|
||||
transcript = payload.get("transcript", "")
|
||||
return json.dumps({"cleaned_text": transcript})
|
||||
return "general"
|
||||
|
||||
|
||||
class FakeAudio:
|
||||
def __init__(self, size: int):
|
||||
self.size = size
|
||||
|
||||
|
||||
class FakeNotifier:
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
|
||||
def send(self, title, body, *, error=False):
|
||||
self.events.append((title, body, error))
|
||||
return True
|
||||
|
||||
|
||||
class DaemonTests(unittest.TestCase):
|
||||
def _config(self) -> Config:
|
||||
cfg = Config()
|
||||
|
|
@ -103,19 +133,23 @@ class DaemonTests(unittest.TestCase):
|
|||
|
||||
@patch("aman.stop_audio_recording", return_value=FakeAudio(8))
|
||||
@patch("aman.start_audio_recording", return_value=(object(), object()))
|
||||
def test_toggle_start_stop_injects_text(self, _start_mock, _stop_mock):
|
||||
def test_hotkey_start_stop_injects_text(self, _start_mock, _stop_mock):
|
||||
desktop = FakeDesktop()
|
||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
||||
daemon._start_stop_worker = (
|
||||
lambda stream, record, trigger, process_audio: daemon._stop_and_process(
|
||||
stream, record, trigger, process_audio
|
||||
stream,
|
||||
record,
|
||||
trigger,
|
||||
process_audio,
|
||||
daemon._pending_worker_hotkey,
|
||||
)
|
||||
)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
|
||||
self.assertEqual(daemon.get_state(), aman.State.IDLE)
|
||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", False)])
|
||||
|
|
@ -131,7 +165,7 @@ class DaemonTests(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
self.assertEqual(daemon.get_state(), aman.State.RECORDING)
|
||||
|
||||
self.assertTrue(daemon.shutdown(timeout=0.2))
|
||||
|
|
@ -153,8 +187,8 @@ class DaemonTests(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.toggle()
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
|
||||
self.assertEqual(desktop.inject_calls, [("good morning Marta", "clipboard", False)])
|
||||
|
||||
|
|
@ -223,8 +257,8 @@ class DaemonTests(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
daemon.toggle()
|
||||
daemon.toggle()
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
daemon.handle_hotkey(daemon.cfg.daemon.hotkey)
|
||||
|
||||
self.assertEqual(desktop.inject_calls, [("hello world", "clipboard", True)])
|
||||
|
||||
|
|
@ -239,6 +273,161 @@ class DaemonTests(unittest.TestCase):
|
|||
any("DEBUG:root:state: idle -> recording" in line for line in logs.output)
|
||||
)
|
||||
|
||||
def test_hotkey_dispatches_to_matching_pipeline(self):
|
||||
desktop = FakeDesktop()
|
||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
||||
daemon.set_pipeline_bindings(
|
||||
{
|
||||
"Super+m": PipelineBinding(
|
||||
hotkey="Super+m",
|
||||
handler=lambda _audio, _lib: "default",
|
||||
options=PipelineOptions(failure_policy="best_effort"),
|
||||
),
|
||||
"Super+Shift+m": PipelineBinding(
|
||||
hotkey="Super+Shift+m",
|
||||
handler=lambda _audio, _lib: "caps",
|
||||
options=PipelineOptions(failure_policy="best_effort"),
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
out = daemon._run_pipeline(object(), "Super+Shift+m")
|
||||
self.assertEqual(out, "caps")
|
||||
|
||||
def test_try_apply_config_applies_new_runtime_settings(self):
|
||||
desktop = FakeDesktop()
|
||||
cfg = self._config()
|
||||
daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False)
|
||||
daemon._stt_hint_kwargs_cache = {"hotwords": "old"}
|
||||
|
||||
candidate = Config()
|
||||
candidate.injection.backend = "injection"
|
||||
candidate.vocabulary.replacements = [
|
||||
VocabularyReplacement(source="Martha", target="Marta"),
|
||||
]
|
||||
|
||||
status, changed, error = daemon.try_apply_config(candidate)
|
||||
|
||||
self.assertEqual(status, "applied")
|
||||
self.assertEqual(error, "")
|
||||
self.assertIn("injection.backend", changed)
|
||||
self.assertIn("vocabulary.replacements", changed)
|
||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
||||
self.assertEqual(
|
||||
daemon.vocabulary.apply_deterministic_replacements("Martha"),
|
||||
"Marta",
|
||||
)
|
||||
self.assertIsNone(daemon._stt_hint_kwargs_cache)
|
||||
|
||||
def test_try_apply_config_reloads_stt_model(self):
|
||||
desktop = FakeDesktop()
|
||||
cfg = self._config()
|
||||
daemon = self._build_daemon(desktop, FakeModel(), cfg=cfg, verbose=False)
|
||||
|
||||
candidate = Config()
|
||||
candidate.stt.model = "small"
|
||||
next_model = FakeModel(text="from-new-model")
|
||||
with patch("aman._build_whisper_model", return_value=next_model):
|
||||
status, changed, error = daemon.try_apply_config(candidate)
|
||||
|
||||
self.assertEqual(status, "applied")
|
||||
self.assertEqual(error, "")
|
||||
self.assertIn("stt.model", changed)
|
||||
self.assertIs(daemon.model, next_model)
|
||||
|
||||
def test_try_apply_config_is_deferred_while_busy(self):
|
||||
desktop = FakeDesktop()
|
||||
daemon = self._build_daemon(desktop, FakeModel(), verbose=False)
|
||||
daemon.set_state(aman.State.RECORDING)
|
||||
|
||||
status, changed, error = daemon.try_apply_config(Config())
|
||||
|
||||
self.assertEqual(status, "deferred")
|
||||
self.assertEqual(changed, [])
|
||||
self.assertIn("busy", error)
|
||||
|
||||
|
||||
class ConfigReloaderTests(unittest.TestCase):
|
||||
def _write_config(self, path: Path, cfg: Config):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(redacted_dict(cfg)), encoding="utf-8")
|
||||
|
||||
def _daemon_for_path(self, path: Path) -> tuple[aman.Daemon, FakeDesktop]:
|
||||
cfg = Config()
|
||||
self._write_config(path, cfg)
|
||||
desktop = FakeDesktop()
|
||||
with patch("aman._build_whisper_model", return_value=FakeModel()), patch(
|
||||
"aman.LlamaProcessor", return_value=FakeAIProcessor()
|
||||
):
|
||||
daemon = aman.Daemon(cfg, desktop, verbose=False)
|
||||
return daemon, desktop
|
||||
|
||||
def test_reloader_applies_changed_config(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
daemon, _desktop = self._daemon_for_path(path)
|
||||
notifier = FakeNotifier()
|
||||
reloader = aman.ConfigReloader(
|
||||
daemon=daemon,
|
||||
config_path=path,
|
||||
notifier=notifier,
|
||||
poll_interval_sec=0.01,
|
||||
)
|
||||
|
||||
updated = Config()
|
||||
updated.injection.backend = "injection"
|
||||
self._write_config(path, updated)
|
||||
reloader.tick()
|
||||
|
||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
||||
self.assertTrue(any(evt[0] == "Config Reloaded" and not evt[2] for evt in notifier.events))
|
||||
|
||||
def test_reloader_keeps_last_good_config_when_invalid(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
daemon, _desktop = self._daemon_for_path(path)
|
||||
notifier = FakeNotifier()
|
||||
reloader = aman.ConfigReloader(
|
||||
daemon=daemon,
|
||||
config_path=path,
|
||||
notifier=notifier,
|
||||
poll_interval_sec=0.01,
|
||||
)
|
||||
|
||||
path.write_text('{"injection":{"backend":"invalid"}}', encoding="utf-8")
|
||||
reloader.tick()
|
||||
self.assertEqual(daemon.cfg.injection.backend, "clipboard")
|
||||
fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
|
||||
self.assertEqual(len(fail_events), 1)
|
||||
|
||||
reloader.tick()
|
||||
fail_events = [evt for evt in notifier.events if evt[0] == "Config Reload Failed"]
|
||||
self.assertEqual(len(fail_events), 1)
|
||||
|
||||
def test_reloader_defers_apply_until_idle(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
daemon, _desktop = self._daemon_for_path(path)
|
||||
notifier = FakeNotifier()
|
||||
reloader = aman.ConfigReloader(
|
||||
daemon=daemon,
|
||||
config_path=path,
|
||||
notifier=notifier,
|
||||
poll_interval_sec=0.01,
|
||||
)
|
||||
|
||||
updated = Config()
|
||||
updated.injection.backend = "injection"
|
||||
self._write_config(path, updated)
|
||||
|
||||
daemon.set_state(aman.State.RECORDING)
|
||||
reloader.tick()
|
||||
self.assertEqual(daemon.cfg.injection.backend, "clipboard")
|
||||
|
||||
daemon.set_state(aman.State.IDLE)
|
||||
reloader.tick()
|
||||
self.assertEqual(daemon.cfg.injection.backend, "injection")
|
||||
|
||||
|
||||
class LockTests(unittest.TestCase):
|
||||
def test_lock_rejects_second_instance(self):
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||
self.assertEqual(cfg.vocabulary.terms, [])
|
||||
self.assertTrue(cfg.domain_inference.enabled)
|
||||
|
||||
self.assertTrue(missing.exists())
|
||||
written = json.loads(missing.read_text(encoding="utf-8"))
|
||||
|
|
@ -48,7 +47,6 @@ class ConfigTests(unittest.TestCase):
|
|||
],
|
||||
"terms": ["Systemd", "Kubernetes"],
|
||||
},
|
||||
"domain_inference": {"enabled": True},
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
|
|
@ -66,7 +64,6 @@ class ConfigTests(unittest.TestCase):
|
|||
self.assertEqual(cfg.vocabulary.replacements[0].source, "Martha")
|
||||
self.assertEqual(cfg.vocabulary.replacements[0].target, "Marta")
|
||||
self.assertEqual(cfg.vocabulary.terms, ["Systemd", "Kubernetes"])
|
||||
self.assertTrue(cfg.domain_inference.enabled)
|
||||
|
||||
def test_super_modifier_hotkey_is_valid(self):
|
||||
payload = {"daemon": {"hotkey": "Super+m"}}
|
||||
|
|
@ -98,28 +95,6 @@ class ConfigTests(unittest.TestCase):
|
|||
):
|
||||
load(str(path))
|
||||
|
||||
def test_loads_legacy_keys(self):
|
||||
payload = {
|
||||
"hotkey": "Alt+m",
|
||||
"input": "Mic",
|
||||
"whisper_model": "tiny",
|
||||
"whisper_device": "cpu",
|
||||
"injection_backend": "clipboard",
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
cfg = load(str(path))
|
||||
|
||||
self.assertEqual(cfg.daemon.hotkey, "Alt+m")
|
||||
self.assertEqual(cfg.recording.input, "Mic")
|
||||
self.assertEqual(cfg.stt.model, "tiny")
|
||||
self.assertEqual(cfg.stt.device, "cpu")
|
||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||
self.assertFalse(cfg.injection.remove_transcription_from_clipboard)
|
||||
self.assertEqual(cfg.vocabulary.replacements, [])
|
||||
|
||||
def test_invalid_injection_backend_raises(self):
|
||||
payload = {"injection": {"backend": "invalid"}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
|
|
@ -138,41 +113,20 @@ class ConfigTests(unittest.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "injection.remove_transcription_from_clipboard"):
|
||||
load(str(path))
|
||||
|
||||
def test_removed_ai_section_raises(self):
|
||||
payload = {"ai": {"enabled": True}}
|
||||
def test_unknown_top_level_fields_are_ignored(self):
|
||||
payload = {
|
||||
"custom_a": {"enabled": True},
|
||||
"custom_b": {"nested": "value"},
|
||||
"custom_c": 123,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "ai section is no longer supported"):
|
||||
load(str(path))
|
||||
cfg = load(str(path))
|
||||
|
||||
def test_removed_legacy_ai_enabled_raises(self):
|
||||
payload = {"ai_enabled": True}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "ai_enabled is no longer supported"):
|
||||
load(str(path))
|
||||
|
||||
def test_removed_logging_section_raises(self):
|
||||
payload = {"logging": {"log_transcript": True}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||
load(str(path))
|
||||
|
||||
def test_removed_legacy_log_transcript_raises(self):
|
||||
payload = {"log_transcript": True}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "no longer supported"):
|
||||
load(str(path))
|
||||
self.assertEqual(cfg.daemon.hotkey, "Cmd+m")
|
||||
self.assertEqual(cfg.injection.backend, "clipboard")
|
||||
|
||||
def test_conflicting_replacements_raise(self):
|
||||
payload = {
|
||||
|
|
@ -224,32 +178,15 @@ class ConfigTests(unittest.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "wildcard"):
|
||||
load(str(path))
|
||||
|
||||
def test_removed_domain_mode_raises(self):
|
||||
payload = {"domain_inference": {"mode": "heuristic"}}
|
||||
def test_unknown_vocabulary_fields_are_ignored(self):
|
||||
payload = {"vocabulary": {"custom_limit": 100, "custom_extra": 200, "terms": ["Docker"]}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "domain_inference.mode is no longer supported"):
|
||||
load(str(path))
|
||||
cfg = load(str(path))
|
||||
|
||||
def test_removed_vocabulary_max_rules_raises(self):
|
||||
payload = {"vocabulary": {"max_rules": 100}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "vocabulary.max_rules is no longer supported"):
|
||||
load(str(path))
|
||||
|
||||
def test_removed_vocabulary_max_terms_raises(self):
|
||||
payload = {"vocabulary": {"max_terms": 100}}
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "config.json"
|
||||
path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "vocabulary.max_terms is no longer supported"):
|
||||
load(str(path))
|
||||
self.assertEqual(cfg.vocabulary.terms, ["Docker"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
92
tests/test_engine.py
Normal file
92
tests/test_engine.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
|
||||
|
||||
|
||||
class EngineTests(unittest.TestCase):
|
||||
def test_best_effort_pipeline_failure_returns_empty_string(self):
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
||||
llm_fn=lambda **_kwargs: "result",
|
||||
)
|
||||
engine = Engine(lib)
|
||||
|
||||
def failing_pipeline(_audio, _lib):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
binding = PipelineBinding(
|
||||
hotkey="Cmd+m",
|
||||
handler=failing_pipeline,
|
||||
options=PipelineOptions(failure_policy="best_effort"),
|
||||
)
|
||||
out = engine.run(binding, object())
|
||||
self.assertEqual(out, "")
|
||||
|
||||
def test_strict_pipeline_failure_raises(self):
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
||||
llm_fn=lambda **_kwargs: "result",
|
||||
)
|
||||
engine = Engine(lib)
|
||||
|
||||
def failing_pipeline(_audio, _lib):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
binding = PipelineBinding(
|
||||
hotkey="Cmd+m",
|
||||
handler=failing_pipeline,
|
||||
options=PipelineOptions(failure_policy="strict"),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "boom"):
|
||||
engine.run(binding, object())
|
||||
|
||||
def test_pipeline_lib_forwards_arguments(self):
|
||||
seen = {}
|
||||
|
||||
def transcribe_fn(audio, *, hints=None, whisper_opts=None):
|
||||
seen["audio"] = audio
|
||||
seen["hints"] = hints
|
||||
seen["whisper_opts"] = whisper_opts
|
||||
return "hello"
|
||||
|
||||
def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
|
||||
seen["system_prompt"] = system_prompt
|
||||
seen["user_prompt"] = user_prompt
|
||||
seen["llm_opts"] = llm_opts
|
||||
return "world"
|
||||
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=transcribe_fn,
|
||||
llm_fn=llm_fn,
|
||||
)
|
||||
|
||||
audio = object()
|
||||
self.assertEqual(
|
||||
lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
|
||||
"hello",
|
||||
)
|
||||
self.assertEqual(
|
||||
lib.llm(
|
||||
system_prompt="sys",
|
||||
user_prompt="user",
|
||||
llm_opts={"temperature": 0.2},
|
||||
),
|
||||
"world",
|
||||
)
|
||||
self.assertIs(seen["audio"], audio)
|
||||
self.assertEqual(seen["hints"], ["Docker"])
|
||||
self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
|
||||
self.assertEqual(seen["system_prompt"], "sys")
|
||||
self.assertEqual(seen["user_prompt"], "user")
|
||||
self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
109
tests/test_pipelines_runtime.py
Normal file
109
tests/test_pipelines_runtime.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import tempfile
|
||||
import textwrap
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from pipelines_runtime import fingerprint, load_bindings
|
||||
|
||||
|
||||
class PipelinesRuntimeTests(unittest.TestCase):
|
||||
def test_missing_file_uses_default_binding(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "pipelines.py"
|
||||
called = {"count": 0}
|
||||
|
||||
def factory():
|
||||
called["count"] += 1
|
||||
|
||||
def handler(_audio, _lib):
|
||||
return "ok"
|
||||
|
||||
return handler
|
||||
|
||||
bindings = load_bindings(
|
||||
path=path,
|
||||
default_hotkey="Cmd+m",
|
||||
default_handler_factory=factory,
|
||||
)
|
||||
|
||||
self.assertEqual(list(bindings.keys()), ["Cmd+m"])
|
||||
self.assertEqual(called["count"], 1)
|
||||
|
||||
def test_loads_module_bindings_and_options(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "pipelines.py"
|
||||
path.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
def p1(audio, lib):
|
||||
return "one"
|
||||
|
||||
def p2(audio, lib):
|
||||
return "two"
|
||||
|
||||
HOTKEY_PIPELINES = {
|
||||
"Cmd+m": p1,
|
||||
"Ctrl+space": p2,
|
||||
}
|
||||
|
||||
PIPELINE_OPTIONS = {
|
||||
"Ctrl+space": {"failure_policy": "strict"},
|
||||
}
|
||||
"""
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
bindings = load_bindings(
|
||||
path=path,
|
||||
default_hotkey="Cmd+m",
|
||||
default_handler_factory=lambda: (lambda _audio, _lib: "default"),
|
||||
)
|
||||
|
||||
self.assertEqual(sorted(bindings.keys()), ["Cmd+m", "Ctrl+space"])
|
||||
self.assertEqual(bindings["Cmd+m"].options.failure_policy, "best_effort")
|
||||
self.assertEqual(bindings["Ctrl+space"].options.failure_policy, "strict")
|
||||
|
||||
def test_invalid_options_fail(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "pipelines.py"
|
||||
path.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
def p(audio, lib):
|
||||
return "ok"
|
||||
HOTKEY_PIPELINES = {"Cmd+m": p}
|
||||
PIPELINE_OPTIONS = {"Cmd+m": {"failure_policy": "invalid"}}
|
||||
"""
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "failure_policy"):
|
||||
load_bindings(
|
||||
path=path,
|
||||
default_hotkey="Cmd+m",
|
||||
default_handler_factory=lambda: (lambda _audio, _lib: "default"),
|
||||
)
|
||||
|
||||
def test_fingerprint_changes_with_content(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
path = Path(td) / "pipelines.py"
|
||||
self.assertIsNone(fingerprint(path))
|
||||
path.write_text("HOTKEY_PIPELINES = {}", encoding="utf-8")
|
||||
first = fingerprint(path)
|
||||
path.write_text("HOTKEY_PIPELINES = {'Cmd+m': lambda a, l: ''}", encoding="utf-8")
|
||||
second = fingerprint(path)
|
||||
self.assertIsNotNone(first)
|
||||
self.assertIsNotNone(second)
|
||||
self.assertNotEqual(first, second)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -7,18 +7,17 @@ SRC = ROOT / "src"
|
|||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from config import DomainInferenceConfig, VocabularyConfig, VocabularyReplacement
|
||||
from vocabulary import DOMAIN_GENERAL, VocabularyEngine
|
||||
from config import VocabularyConfig, VocabularyReplacement
|
||||
from vocabulary import VocabularyEngine
|
||||
|
||||
|
||||
class VocabularyEngineTests(unittest.TestCase):
|
||||
def _engine(self, replacements=None, terms=None, domain_enabled=True):
|
||||
def _engine(self, replacements=None, terms=None):
|
||||
vocab = VocabularyConfig(
|
||||
replacements=replacements or [],
|
||||
terms=terms or [],
|
||||
)
|
||||
domain = DomainInferenceConfig(enabled=domain_enabled)
|
||||
return VocabularyEngine(vocab, domain)
|
||||
return VocabularyEngine(vocab)
|
||||
|
||||
def test_boundary_aware_replacement(self):
|
||||
engine = self._engine(
|
||||
|
|
@ -50,27 +49,5 @@ class VocabularyEngineTests(unittest.TestCase):
|
|||
self.assertLessEqual(len(hotwords), 1024)
|
||||
self.assertLessEqual(len(prompt), 600)
|
||||
|
||||
def test_domain_inference_general_fallback(self):
|
||||
engine = self._engine()
|
||||
result = engine.infer_domain("please call me later")
|
||||
|
||||
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
def test_domain_inference_for_technical_text(self):
|
||||
engine = self._engine(terms=["Docker", "Systemd"])
|
||||
result = engine.infer_domain("restart Docker and systemd service on prod")
|
||||
|
||||
self.assertNotEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertGreater(result.confidence, 0.0)
|
||||
|
||||
def test_domain_inference_can_be_disabled(self):
|
||||
engine = self._engine(domain_enabled=False)
|
||||
result = engine.infer_domain("please restart docker")
|
||||
|
||||
self.assertEqual(result.name, DOMAIN_GENERAL)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue