diff --git a/src/aiprocess.py b/src/aiprocess.py
index a098bc4..609677d 100644
--- a/src/aiprocess.py
+++ b/src/aiprocess.py
@@ -69,6 +69,7 @@ class LlamaProcessor:
         lang: str = "en",
         *,
         dictionary_context: str = "",
+        profile: str = "default",
     ) -> str:
         request_payload: dict[str, Any] = {
             "language": lang,
@@ -87,6 +88,7 @@ class LlamaProcessor:
         }
         if _supports_response_format(self.client.create_chat_completion):
             kwargs["response_format"] = {"type": "json_object"}
+        kwargs.update(_profile_generation_kwargs(self.client.create_chat_completion, profile))
         response = self.client.create_chat_completion(**kwargs)
         return _extract_cleaned_text(response)
 
@@ -205,11 +207,25 @@ def _extract_cleaned_text(payload: Any) -> str:
 
 
 def _supports_response_format(chat_completion: Callable[..., Any]) -> bool:
+    return _supports_parameter(chat_completion, "response_format")
+
+
+def _supports_parameter(callable_obj: Callable[..., Any], parameter: str) -> bool:
     try:
-        signature = inspect.signature(chat_completion)
+        signature = inspect.signature(callable_obj)
     except (TypeError, ValueError):
         return False
-    return "response_format" in signature.parameters
+    return parameter in signature.parameters
+
+
+def _profile_generation_kwargs(chat_completion: Callable[..., Any], profile: str) -> dict[str, Any]:
+    normalized = (profile or "default").strip().lower()
+    if normalized != "fast":
+        return {}
+    if not _supports_parameter(chat_completion, "max_tokens"):
+        return {}
+    # Faster profile trades completion depth for lower latency.
+    return {"max_tokens": 192}
 
 
 def _llama_log_callback_factory(verbose: bool) -> Callable:
diff --git a/src/aman.py b/src/aman.py
index fda7edc..49880c9 100755
--- a/src/aman.py
+++ b/src/aman.py
@@ -233,6 +233,7 @@ class Daemon:
                 text,
                 lang=STT_LANGUAGE,
                 dictionary_context=self.vocabulary.build_ai_dictionary_context(),
+                profile=self.cfg.ux.profile,
             )
             if ai_text and ai_text.strip():
                 text = ai_text.strip()
diff --git a/tests/test_aiprocess.py b/tests/test_aiprocess.py
index 05ca6ab..40fc188 100644
--- a/tests/test_aiprocess.py
+++ b/tests/test_aiprocess.py
@@ -14,6 +14,7 @@ import aiprocess
 from aiprocess import (
     _assert_expected_model_checksum,
     _extract_cleaned_text,
+    _profile_generation_kwargs,
     _supports_response_format,
     ensure_model,
 )
@@ -93,6 +94,22 @@ class SupportsResponseFormatTests(unittest.TestCase):
 
         self.assertFalse(_supports_response_format(chat_completion))
 
+    def test_fast_profile_sets_max_tokens_when_supported(self):
+        def chat_completion(*, messages, temperature, max_tokens):
+            return None
+
+        kwargs = _profile_generation_kwargs(chat_completion, "fast")
+
+        self.assertEqual(kwargs, {"max_tokens": 192})
+
+    def test_non_fast_profile_does_not_set_generation_overrides(self):
+        def chat_completion(*, messages, temperature, max_tokens):
+            return None
+
+        kwargs = _profile_generation_kwargs(chat_completion, "polished")
+
+        self.assertEqual(kwargs, {})
+
 
 class ModelChecksumTests(unittest.TestCase):
     def test_accepts_expected_checksum_case_insensitive(self):
diff --git a/tests/test_aman.py b/tests/test_aman.py
index f38134d..4091137 100644
--- a/tests/test_aman.py
+++ b/tests/test_aman.py
@@ -87,7 +87,11 @@ class FakeHintModel:
 
 
 class FakeAIProcessor:
+    def __init__(self):
+        self.last_kwargs = {}
+
     def process(self, text, lang="en", **_kwargs):
+        self.last_kwargs = {"lang": lang, **_kwargs}
         return text
 
 
@@ -120,10 +124,12 @@ class DaemonTests(unittest.TestCase):
         *,
         cfg: Config | None = None,
         verbose: bool = False,
+        ai_processor: FakeAIProcessor | None = None,
     ) -> aman.Daemon:
         active_cfg = cfg if cfg is not None else self._config()
+        active_ai_processor = ai_processor or FakeAIProcessor()
         with patch("aman._build_whisper_model", return_value=model), patch(
-            "aman.LlamaProcessor", return_value=FakeAIProcessor()
+            "aman.LlamaProcessor", return_value=active_ai_processor
        ):
             return aman.Daemon(active_cfg, desktop, verbose=verbose)
 
@@ -302,6 +308,31 @@ class DaemonTests(unittest.TestCase):
         self.assertEqual(stream.stop_calls, 1)
         self.assertEqual(stream.close_calls, 1)
 
+    @patch("aman.stop_audio_recording", return_value=FakeAudio(8))
+    @patch("aman.start_audio_recording", return_value=(object(), object()))
+    def test_ai_processor_receives_active_profile(self, _start_mock, _stop_mock):
+        desktop = FakeDesktop()
+        cfg = self._config()
+        cfg.ux.profile = "fast"
+        ai_processor = FakeAIProcessor()
+        daemon = self._build_daemon(
+            desktop,
+            FakeModel(text="hello world"),
+            cfg=cfg,
+            verbose=False,
+            ai_processor=ai_processor,
+        )
+        daemon._start_stop_worker = (
+            lambda stream, record, trigger, process_audio: daemon._stop_and_process(
+                stream, record, trigger, process_audio
+            )
+        )
+
+        daemon.toggle()
+        daemon.toggle()
+
+        self.assertEqual(ai_processor.last_kwargs.get("profile"), "fast")
+
 
 class LockTests(unittest.TestCase):
     def test_lock_rejects_second_instance(self):