From 8b3532f2ca3ecdfd388d3710a915e494fb5492ad Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 26 Feb 2026 16:37:36 -0300 Subject: [PATCH] Harden recorder shutdown with guaranteed stream close --- src/recorder.py | 16 +++++++-- tests/test_recorder.py | 80 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 tests/test_recorder.py diff --git a/src/recorder.py b/src/recorder.py index 00aa51b..fc4501a 100644 --- a/src/recorder.py +++ b/src/recorder.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +import logging from typing import Any, Iterable import numpy as np @@ -70,8 +71,19 @@ def start_recording(input_spec: str | int | None) -> tuple[Any, RecordResult]: def stop_recording(stream: Any, record: RecordResult) -> np.ndarray: if stream: - stream.stop() - stream.close() + stop_error = None + try: + stream.stop() + except Exception as exc: + stop_error = exc + try: + stream.close() + except Exception as close_exc: + if stop_error is None: + raise + logging.warning("stream close failed after stop failure: %s", close_exc) + if stop_error is not None: + raise stop_error return _flatten_frames(record.frames) diff --git a/tests/test_recorder.py b/tests/test_recorder.py new file mode 100644 index 0000000..e6d98fb --- /dev/null +++ b/tests/test_recorder.py @@ -0,0 +1,80 @@ +import sys +import unittest +from pathlib import Path + +import numpy as np + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from recorder import RecordResult, stop_recording + + +class _Stream: + def __init__(self, *, stop_exc: Exception | None = None, close_exc: Exception | None = None): + self.stop_exc = stop_exc + self.close_exc = close_exc + self.stop_calls = 0 + self.close_calls = 0 + + def stop(self) -> None: + self.stop_calls += 1 + if self.stop_exc is not None: + raise self.stop_exc + + def close(self) -> None: + self.close_calls += 1 + if self.close_exc is not None: + raise self.close_exc + + +class RecorderTests(unittest.TestCase): + def test_stop_recording_closes_stream_when_stop_raises(self): + stream = _Stream(stop_exc=RuntimeError("stop boom")) + record = RecordResult() + + with self.assertRaisesRegex(RuntimeError, "stop boom"): + stop_recording(stream, record) + + self.assertEqual(stream.stop_calls, 1) + self.assertEqual(stream.close_calls, 1) + + def test_stop_recording_raises_close_error_when_stop_succeeds(self): + stream = _Stream(close_exc=RuntimeError("close boom")) + record = RecordResult() + + with self.assertRaisesRegex(RuntimeError, "close boom"): + stop_recording(stream, record) + + self.assertEqual(stream.stop_calls, 1) + self.assertEqual(stream.close_calls, 1) + + def test_stop_recording_raises_stop_error_when_both_fail(self): + stream = _Stream( + stop_exc=RuntimeError("stop boom"), + close_exc=RuntimeError("close boom"), + ) + record = RecordResult() + + with self.assertLogs(level="WARNING") as logs: + with self.assertRaisesRegex(RuntimeError, "stop boom"): + stop_recording(stream, record) + + self.assertEqual(stream.stop_calls, 1) + self.assertEqual(stream.close_calls, 1) + self.assertTrue( + any("stream close failed after stop failure: close boom" in line for line in logs.output) + ) + + def test_stop_recording_with_no_stream_flattens_frames(self): + record = RecordResult(frames=[np.array([[0.2], [0.4]], dtype=np.float32)]) + + audio = stop_recording(None, record) + + np.testing.assert_allclose(audio, np.array([0.2, 0.4], dtype=np.float32)) + + +if __name__ == "__main__": + unittest.main()