# aman/tests/test_engine.py
#
# 92 lines
# 2.8 KiB
# Python

import sys
import unittest
from pathlib import Path
# ROOT is the directory one level above the folder holding this test file
# (i.e. the project root when this file lives in tests/).
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
# Prepend src/ to the import path exactly once so ``from engine import ...``
# below resolves without the package being installed.
if not str(SRC) in sys.path:
    sys.path.insert(0, str(SRC))
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
class EngineTests(unittest.TestCase):
    """Tests for ``Engine`` failure policies and ``PipelineLib`` forwarding."""

    @staticmethod
    def _make_lib():
        """Build a ``PipelineLib`` whose backends return fixed strings."""
        return PipelineLib(
            transcribe_fn=lambda *_args, **_kwargs: "text",
            llm_fn=lambda **_kwargs: "result",
        )

    @staticmethod
    def _failing_binding(failure_policy, calls):
        """Build a ``PipelineBinding`` whose handler always raises.

        Args:
            failure_policy: Value passed as
                ``PipelineOptions(failure_policy=...)``.
            calls: List the handler appends to each time it is invoked, so a
                test can confirm the handler actually ran.

        Returns:
            A binding on hotkey ``"Cmd+m"`` whose handler raises
            ``RuntimeError("boom")``.
        """

        def failing_pipeline(_audio, _lib):
            calls.append(1)
            raise RuntimeError("boom")

        return PipelineBinding(
            hotkey="Cmd+m",
            handler=failing_pipeline,
            options=PipelineOptions(failure_policy=failure_policy),
        )

    def test_best_effort_pipeline_failure_returns_empty_string(self):
        """A failing handler under ``best_effort`` yields ``""``."""
        engine = Engine(self._make_lib())
        calls = []
        binding = self._failing_binding("best_effort", calls)

        out = engine.run(binding, object())

        self.assertEqual(out, "")
        # Without this check, an engine that skipped the handler entirely
        # could also return "" and the test would pass for the wrong reason.
        self.assertTrue(calls, "pipeline handler was never invoked")

    def test_strict_pipeline_failure_raises(self):
        """A failing handler under ``strict`` propagates its exception."""
        engine = Engine(self._make_lib())
        binding = self._failing_binding("strict", [])

        with self.assertRaisesRegex(RuntimeError, "boom"):
            engine.run(binding, object())

    def test_pipeline_lib_forwards_arguments(self):
        """``PipelineLib.transcribe``/``llm`` pass their arguments through."""
        seen = {}

        def transcribe_fn(audio, *, hints=None, whisper_opts=None):
            seen["audio"] = audio
            seen["hints"] = hints
            seen["whisper_opts"] = whisper_opts
            return "hello"

        def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
            seen["system_prompt"] = system_prompt
            seen["user_prompt"] = user_prompt
            seen["llm_opts"] = llm_opts
            return "world"

        lib = PipelineLib(
            transcribe_fn=transcribe_fn,
            llm_fn=llm_fn,
        )
        audio = object()

        self.assertEqual(
            lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
            "hello",
        )
        self.assertEqual(
            lib.llm(
                system_prompt="sys",
                user_prompt="user",
                llm_opts={"temperature": 0.2},
            ),
            "world",
        )

        # Identity check: the exact audio object must reach the backend.
        self.assertIs(seen["audio"], audio)
        self.assertEqual(seen["hints"], ["Docker"])
        self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
        self.assertEqual(seen["system_prompt"], "sys")
        self.assertEqual(seen["user_prompt"], "user")
        self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()