Harden recorder shutdown with guaranteed stream close
This commit is contained in:
parent
48d7460f57
commit
8b3532f2ca
2 changed files with 94 additions and 2 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
import logging
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -70,8 +71,19 @@ def start_recording(input_spec: str | int | None) -> tuple[Any, RecordResult]:
|
||||||
|
|
||||||
def stop_recording(stream: Any, record: RecordResult) -> np.ndarray:
|
def stop_recording(stream: Any, record: RecordResult) -> np.ndarray:
|
||||||
if stream:
|
if stream:
|
||||||
stream.stop()
|
stop_error = None
|
||||||
stream.close()
|
try:
|
||||||
|
stream.stop()
|
||||||
|
except Exception as exc:
|
||||||
|
stop_error = exc
|
||||||
|
try:
|
||||||
|
stream.close()
|
||||||
|
except Exception as close_exc:
|
||||||
|
if stop_error is None:
|
||||||
|
raise
|
||||||
|
logging.warning("stream close failed after stop failure: %s", close_exc)
|
||||||
|
if stop_error is not None:
|
||||||
|
raise stop_error
|
||||||
return _flatten_frames(record.frames)
|
return _flatten_frames(record.frames)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
80
tests/test_recorder.py
Normal file
80
tests/test_recorder.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
if str(SRC) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
||||||
|
from recorder import RecordResult, stop_recording
|
||||||
|
|
||||||
|
|
||||||
|
class _Stream:
|
||||||
|
def __init__(self, *, stop_exc: Exception | None = None, close_exc: Exception | None = None):
|
||||||
|
self.stop_exc = stop_exc
|
||||||
|
self.close_exc = close_exc
|
||||||
|
self.stop_calls = 0
|
||||||
|
self.close_calls = 0
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.stop_calls += 1
|
||||||
|
if self.stop_exc is not None:
|
||||||
|
raise self.stop_exc
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self.close_calls += 1
|
||||||
|
if self.close_exc is not None:
|
||||||
|
raise self.close_exc
|
||||||
|
|
||||||
|
|
||||||
|
class RecorderTests(unittest.TestCase):
|
||||||
|
def test_stop_recording_closes_stream_when_stop_raises(self):
|
||||||
|
stream = _Stream(stop_exc=RuntimeError("stop boom"))
|
||||||
|
record = RecordResult()
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "stop boom"):
|
||||||
|
stop_recording(stream, record)
|
||||||
|
|
||||||
|
self.assertEqual(stream.stop_calls, 1)
|
||||||
|
self.assertEqual(stream.close_calls, 1)
|
||||||
|
|
||||||
|
def test_stop_recording_raises_close_error_when_stop_succeeds(self):
|
||||||
|
stream = _Stream(close_exc=RuntimeError("close boom"))
|
||||||
|
record = RecordResult()
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "close boom"):
|
||||||
|
stop_recording(stream, record)
|
||||||
|
|
||||||
|
self.assertEqual(stream.stop_calls, 1)
|
||||||
|
self.assertEqual(stream.close_calls, 1)
|
||||||
|
|
||||||
|
def test_stop_recording_raises_stop_error_when_both_fail(self):
|
||||||
|
stream = _Stream(
|
||||||
|
stop_exc=RuntimeError("stop boom"),
|
||||||
|
close_exc=RuntimeError("close boom"),
|
||||||
|
)
|
||||||
|
record = RecordResult()
|
||||||
|
|
||||||
|
with self.assertLogs(level="WARNING") as logs:
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "stop boom"):
|
||||||
|
stop_recording(stream, record)
|
||||||
|
|
||||||
|
self.assertEqual(stream.stop_calls, 1)
|
||||||
|
self.assertEqual(stream.close_calls, 1)
|
||||||
|
self.assertTrue(
|
||||||
|
any("stream close failed after stop failure: close boom" in line for line in logs.output)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_stop_recording_with_no_stream_flattens_frames(self):
|
||||||
|
record = RecordResult(frames=[np.array([[0.2], [0.4]], dtype=np.float32)])
|
||||||
|
|
||||||
|
audio = stop_recording(None, record)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(audio, np.array([0.2, 0.4], dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue