92 lines
2.8 KiB
Python
92 lines
2.8 KiB
Python
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Locate the repository root (the parent of the directory holding this file)
# and put its ``src`` directory at the front of ``sys.path`` so that the
# ``engine`` module can be imported without installing the project.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"

_SRC_ENTRY = str(SRC)
if _SRC_ENTRY not in sys.path:
    sys.path.insert(0, _SRC_ENTRY)
del _SRC_ENTRY
|
|
|
|
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
|
|
|
|
|
|
class EngineTests(unittest.TestCase):
    """Tests for ``Engine`` failure policies and ``PipelineLib`` argument forwarding."""

    @staticmethod
    def _make_engine():
        """Build an ``Engine`` over a ``PipelineLib`` whose stubs return fixed strings."""
        lib = PipelineLib(
            transcribe_fn=lambda *_args, **_kwargs: "text",
            llm_fn=lambda **_kwargs: "result",
        )
        return Engine(lib)

    @staticmethod
    def _failing_binding(failure_policy, calls):
        """Return a ``PipelineBinding`` whose handler records each call, then raises.

        Args:
            failure_policy: value for ``PipelineOptions(failure_policy=...)``.
            calls: list that receives one ``(audio, lib)`` tuple per handler
                invocation, so a test can prove the handler actually ran.

        Returns:
            A binding on hotkey ``"Cmd+m"`` whose handler always raises
            ``RuntimeError("boom")``.
        """

        def failing_pipeline(audio, lib):
            calls.append((audio, lib))
            raise RuntimeError("boom")

        return PipelineBinding(
            hotkey="Cmd+m",
            handler=failing_pipeline,
            options=PipelineOptions(failure_policy=failure_policy),
        )

    def test_best_effort_pipeline_failure_returns_empty_string(self):
        """A raising handler under ``best_effort`` yields ``""`` instead of an error."""
        engine = self._make_engine()
        calls = []
        binding = self._failing_binding("best_effort", calls)

        out = engine.run(binding, object())

        # Guard against a vacuous pass: the "" must come from swallowing the
        # handler's error, not from the handler never being invoked at all.
        self.assertTrue(calls, "pipeline handler was never invoked")
        self.assertEqual(out, "")

    def test_strict_pipeline_failure_raises(self):
        """A raising handler under ``strict`` propagates its exception to the caller."""
        engine = self._make_engine()
        binding = self._failing_binding("strict", [])

        with self.assertRaisesRegex(RuntimeError, "boom"):
            engine.run(binding, object())

    def test_pipeline_lib_forwards_arguments(self):
        """``PipelineLib.transcribe``/``llm`` pass arguments through and return results."""
        seen = {}

        def transcribe_fn(audio, *, hints=None, whisper_opts=None):
            seen["audio"] = audio
            seen["hints"] = hints
            seen["whisper_opts"] = whisper_opts
            return "hello"

        def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
            seen["system_prompt"] = system_prompt
            seen["user_prompt"] = user_prompt
            seen["llm_opts"] = llm_opts
            return "world"

        lib = PipelineLib(
            transcribe_fn=transcribe_fn,
            llm_fn=llm_fn,
        )

        audio = object()
        self.assertEqual(
            lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
            "hello",
        )
        self.assertEqual(
            lib.llm(
                system_prompt="sys",
                user_prompt="user",
                llm_opts={"temperature": 0.2},
            ),
            "world",
        )
        # The audio object must be forwarded by identity, not copied.
        self.assertIs(seen["audio"], audio)
        self.assertEqual(seen["hints"], ["Docker"])
        self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
        self.assertEqual(seen["system_prompt"], "sys")
        self.assertEqual(seen["user_prompt"], "user")
        self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly with the interpreter, in addition to
    # discovery via ``python -m unittest``; ``module`` names the default target.
    unittest.main(module="__main__")
|