Add pipeline engine and remove legacy compatibility paths
This commit is contained in:
parent
3bc473262d
commit
e221d49020
18 changed files with 1523 additions and 399 deletions
92
tests/test_engine.py
Normal file
92
tests/test_engine.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from engine import Engine, PipelineBinding, PipelineLib, PipelineOptions
|
||||
|
||||
|
||||
class EngineTests(unittest.TestCase):
|
||||
def test_best_effort_pipeline_failure_returns_empty_string(self):
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
||||
llm_fn=lambda **_kwargs: "result",
|
||||
)
|
||||
engine = Engine(lib)
|
||||
|
||||
def failing_pipeline(_audio, _lib):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
binding = PipelineBinding(
|
||||
hotkey="Cmd+m",
|
||||
handler=failing_pipeline,
|
||||
options=PipelineOptions(failure_policy="best_effort"),
|
||||
)
|
||||
out = engine.run(binding, object())
|
||||
self.assertEqual(out, "")
|
||||
|
||||
def test_strict_pipeline_failure_raises(self):
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=lambda *_args, **_kwargs: "text",
|
||||
llm_fn=lambda **_kwargs: "result",
|
||||
)
|
||||
engine = Engine(lib)
|
||||
|
||||
def failing_pipeline(_audio, _lib):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
binding = PipelineBinding(
|
||||
hotkey="Cmd+m",
|
||||
handler=failing_pipeline,
|
||||
options=PipelineOptions(failure_policy="strict"),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "boom"):
|
||||
engine.run(binding, object())
|
||||
|
||||
def test_pipeline_lib_forwards_arguments(self):
|
||||
seen = {}
|
||||
|
||||
def transcribe_fn(audio, *, hints=None, whisper_opts=None):
|
||||
seen["audio"] = audio
|
||||
seen["hints"] = hints
|
||||
seen["whisper_opts"] = whisper_opts
|
||||
return "hello"
|
||||
|
||||
def llm_fn(*, system_prompt, user_prompt, llm_opts=None):
|
||||
seen["system_prompt"] = system_prompt
|
||||
seen["user_prompt"] = user_prompt
|
||||
seen["llm_opts"] = llm_opts
|
||||
return "world"
|
||||
|
||||
lib = PipelineLib(
|
||||
transcribe_fn=transcribe_fn,
|
||||
llm_fn=llm_fn,
|
||||
)
|
||||
|
||||
audio = object()
|
||||
self.assertEqual(
|
||||
lib.transcribe(audio, hints=["Docker"], whisper_opts={"vad_filter": True}),
|
||||
"hello",
|
||||
)
|
||||
self.assertEqual(
|
||||
lib.llm(
|
||||
system_prompt="sys",
|
||||
user_prompt="user",
|
||||
llm_opts={"temperature": 0.2},
|
||||
),
|
||||
"world",
|
||||
)
|
||||
self.assertIs(seen["audio"], audio)
|
||||
self.assertEqual(seen["hints"], ["Docker"])
|
||||
self.assertEqual(seen["whisper_opts"], {"vad_filter": True})
|
||||
self.assertEqual(seen["system_prompt"], "sys")
|
||||
self.assertEqual(seen["user_prompt"], "user")
|
||||
self.assertEqual(seen["llm_opts"], {"temperature": 0.2})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue