aman/src/recorder.py

107 lines
2.9 KiB
Python

from dataclasses import dataclass, field
import logging
from typing import Any, Iterable
import numpy as np
@dataclass
class RecordResult:
    """Stream settings plus the audio blocks captured for a single recording.

    ``start_recording`` builds the input stream from ``samplerate``, ``channels``
    and ``dtype`` and appends a copy of every callback block to ``frames``.
    """

    # Copied input blocks in arrival order; each instance owns a fresh list.
    frames: list[np.ndarray] = field(default_factory=list)
    # Sample rate handed to the input stream.
    samplerate: int = field(default=16000)
    # Channel count handed to the input stream.
    channels: int = field(default=1)
    # Sample dtype name handed to the input stream.
    dtype: str = field(default="float32")
def list_input_devices() -> list[dict]:
    """Return every device that reports at least one input channel.

    Each entry is ``{"index": <position in sd.query_devices()>, "name": <device name or "">}``.
    Devices whose info lacks ``max_input_channels`` are treated as having none.

    Raises:
        RuntimeError: propagated from ``_sounddevice`` when sounddevice is unavailable.
    """
    sd = _sounddevice()
    return [
        {"index": position, "name": entry.get("name", "")}
        for position, entry in enumerate(sd.query_devices())
        if entry.get("max_input_channels", 0) > 0
    ]
def default_input_device() -> int | None:
    """Return the index of sounddevice's default input device, or None if there is none.

    ``sd.default.device`` is an input/output pair object in sounddevice. It supports
    indexing (``[0]`` is the input side) but is neither a ``tuple`` nor a ``list``.
    The input entry is therefore read by index rather than detected by type; a plain
    int is still accepted as-is.

    Returns:
        A non-negative device index, or None when the setting cannot be indexed,
        is not an int (e.g. a device-name string), or is negative (PortAudio uses a
        negative value, ``paNoDevice``, when no default input device exists).

    Raises:
        RuntimeError: propagated from ``_sounddevice`` when sounddevice is unavailable.
    """
    sd = _sounddevice()
    default = sd.default.device
    if isinstance(default, int):
        candidate = default
    else:
        try:
            candidate = default[0]
        except (TypeError, IndexError, KeyError):
            # None or some other non-indexable setting: no usable default.
            return None
    if isinstance(candidate, int) and candidate >= 0:
        return candidate
    return None
def resolve_input_device(spec: str | int | None) -> int | None:
if spec is None:
return None
if isinstance(spec, int):
return spec
text = str(spec).strip()
if not text:
return None
if text.isdigit():
return int(text)
lowered = text.lower()
for device in list_input_devices():
name = (device.get("name") or "").lower()
if lowered in name:
return int(device["index"])
return None
def start_recording(input_spec: str | int | None) -> tuple[Any, RecordResult]:
    """Open and start an input stream that collects audio into a new RecordResult.

    Args:
        input_spec: Device selector passed to ``resolve_input_device``. If it
            resolves to None (None, blank, or an unmatched name), ``device=None``
            is passed to ``sd.InputStream``.

    Returns:
        ``(stream, record)``: the started stream and the RecordResult whose
        ``frames`` list the stream callback appends to.

    Raises:
        RuntimeError: propagated from ``_sounddevice`` when sounddevice is unavailable.
        Any exception raised by ``sd.InputStream(...)`` or ``stream.start()``; in
        the latter case the stream is closed before the error is re-raised.
    """
    sd = _sounddevice()
    record = RecordResult()
    device = resolve_input_device(input_spec)

    def callback(indata, _frames, _time, status):
        if status:
            # Surface overflow/underflow flags instead of silently dropping them.
            logging.warning("audio input status: %s", status)
        # The callback buffer is not guaranteed to outlive this call, so keep a copy.
        record.frames.append(indata.copy())

    stream = sd.InputStream(
        samplerate=record.samplerate,
        channels=record.channels,
        dtype=record.dtype,
        device=device,
        callback=callback,
    )
    try:
        stream.start()
    except Exception:
        # Don't leak the opened stream when it fails to start; the start error wins.
        try:
            stream.close()
        except Exception as close_exc:
            logging.warning("stream close failed after start failure: %s", close_exc)
        raise
    return stream, record
def stop_recording(stream: Any, record: RecordResult) -> np.ndarray:
    """Stop and close ``stream`` (when truthy), then return the captured audio.

    Args:
        stream: The stream returned by ``start_recording``; a falsy value skips shutdown.
        record: The RecordResult whose ``frames`` are flattened.

    Returns:
        The 1-D float32 array produced by ``_flatten_frames(record.frames)``.

    Raises:
        The exception from ``stream.stop()`` if stopping fails. Closing is still
        attempted, and a close failure in that case is only logged. If stopping
        succeeds, any exception from ``stream.close()`` propagates.
    """
    if not stream:
        return _flatten_frames(record.frames)
    try:
        stream.stop()
    except Exception:
        # Still try to release the stream, but let the original stop() error win.
        try:
            stream.close()
        except Exception as close_exc:
            logging.warning("stream close failed after stop failure: %s", close_exc)
        raise
    stream.close()
    return _flatten_frames(record.frames)
def _sounddevice():
    """Import and return the ``sounddevice`` module.

    The import runs on every call; Python's module cache makes repeat calls cheap.

    Raises:
        RuntimeError: if sounddevice is not installed, or if importing it fails
            with OSError (which sounddevice raises when it cannot load the
            PortAudio shared library).
    """
    try:
        import sounddevice as sd  # type: ignore[import-not-found]
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "sounddevice is not installed; install dependencies with `uv sync --extra x11`"
        ) from exc
    except OSError as exc:
        # Installed but unusable (typically libportaudio missing): report it the same way.
        raise RuntimeError(
            f"sounddevice failed to import (is the PortAudio library installed?): {exc}"
        ) from exc
    return sd
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
    """Join recorded blocks along axis 0 into one 1-D float32 array.

    Args:
        frames: Any iterable of arrays; it is materialized once.

    Returns:
        An empty float32 array for no blocks; otherwise the concatenation with a
        trailing size-1 axis squeezed away, cast to float32 and flattened.

    Raises:
        ValueError: from ``np.squeeze`` when the joined array's last axis is
            larger than 1. NOTE(review): this means multi-channel recordings
            (``channels > 1``) cannot be flattened here; confirm that is intended.
    """
    chunks = list(frames)
    if len(chunks) == 0:
        return np.zeros((0,), dtype=np.float32)
    joined = np.concatenate(chunks, axis=0)
    squeezed = np.squeeze(joined, axis=-1) if joined.ndim > 1 else joined
    return np.asarray(squeezed, dtype=np.float32).reshape(-1)