107 lines
2.9 KiB
Python
107 lines
2.9 KiB
Python
from dataclasses import dataclass, field
|
|
import logging
|
|
from typing import Any, Iterable
|
|
|
|
import numpy as np
|
|
|
|
|
|
@dataclass
class RecordResult:
    """Accumulator for one recording session plus the stream settings used.

    ``start_recording`` creates an instance, passes ``samplerate``,
    ``channels`` and ``dtype`` to the input stream, and appends every captured
    chunk to ``frames``.
    """

    # Captured audio chunks in arrival order; each instance gets its own list.
    frames: list[np.ndarray] = field(default_factory=list)
    # Value forwarded as the stream's ``samplerate`` argument.
    samplerate: int = 16000
    # Value forwarded as the stream's ``channels`` argument.
    channels: int = 1
    # Value forwarded as the stream's ``dtype`` argument.
    dtype: str = "float32"
|
|
|
|
|
|
def list_input_devices() -> list[dict]:
    """Return ``{"index", "name"}`` entries for every input-capable device.

    A device is included when its ``max_input_channels`` entry is greater than
    zero. ``index`` is the device's position in ``query_devices()``; ``name``
    falls back to an empty string when the entry has no name.

    Raises:
        RuntimeError: Propagated from ``_sounddevice()`` when the sounddevice
            package is not installed.
    """
    backend = _sounddevice()
    return [
        {"index": position, "name": entry.get("name", "")}
        for position, entry in enumerate(backend.query_devices())
        if entry.get("max_input_channels", 0) > 0
    ]
|
|
|
|
|
|
def default_input_device() -> int | None:
    """Return the default input device as reported by sounddevice.

    Returns:
        The first (input) element of ``sd.default.device`` when that value is
        an indexable pair, the value itself when it is a plain ``int``, or
        ``None`` when no input entry can be extracted.

    Raises:
        RuntimeError: Propagated from ``_sounddevice()`` when the sounddevice
            package is not installed.
    """
    sd = _sounddevice()
    default = sd.default.device

    if isinstance(default, int):
        return default
    # Indexing a string would yield its first character, not a device.
    if default is None or isinstance(default, (str, bytes)):
        return None

    # sounddevice exposes ``default.device`` as an indexable input/output pair
    # that is not a tuple or list subclass, so index it instead of checking
    # its concrete type. Empty or non-indexable values fall back to None.
    try:
        return default[0]
    except (TypeError, IndexError, KeyError):
        return None
|
|
|
|
|
|
def resolve_input_device(spec: str | int | None) -> int | None:
|
|
if spec is None:
|
|
return None
|
|
if isinstance(spec, int):
|
|
return spec
|
|
text = str(spec).strip()
|
|
if not text:
|
|
return None
|
|
if text.isdigit():
|
|
return int(text)
|
|
lowered = text.lower()
|
|
for device in list_input_devices():
|
|
name = (device.get("name") or "").lower()
|
|
if lowered in name:
|
|
return int(device["index"])
|
|
return None
|
|
|
|
|
|
def start_recording(input_spec: str | int | None) -> tuple[Any, RecordResult]:
    """Open and start an input stream that collects audio into a RecordResult.

    Args:
        input_spec: Device selector handed to ``resolve_input_device``; the
            resolved value is passed as the stream's ``device`` argument.

    Returns:
        A ``(stream, record)`` pair: the started ``sd.InputStream`` and the
        ``RecordResult`` whose ``frames`` list the stream callback appends to.

    Raises:
        RuntimeError: Propagated from ``_sounddevice()`` when the sounddevice
            package is not installed.
        Exception: Whatever ``sd.InputStream(...)`` or ``stream.start()``
            raises. If ``start()`` fails the stream is closed before the error
            is re-raised.
    """
    sd = _sounddevice()
    record = RecordResult()
    device = resolve_input_device(input_spec)

    def callback(indata, _frames, _time, _status):
        # Store a copy of each chunk rather than the array handed in.
        record.frames.append(indata.copy())

    stream = sd.InputStream(
        samplerate=record.samplerate,
        channels=record.channels,
        dtype=record.dtype,
        device=device,
        callback=callback,
    )
    try:
        stream.start()
    except Exception:
        # The stream was opened but never handed to the caller, so nobody else
        # can close it. Release it here and keep the start error primary.
        try:
            stream.close()
        except Exception as close_exc:
            logging.warning("stream close failed after start failure: %s", close_exc)
        raise
    return stream, record
|
|
|
|
|
|
def stop_recording(stream: Any, record: RecordResult) -> np.ndarray:
    """Stop and close ``stream``, then return the recorded audio as one array.

    ``close()`` is always attempted, even when ``stop()`` raised. A stop error
    takes precedence: if both calls fail, the close failure is only logged and
    the stop error is raised. If only ``close()`` fails, that error propagates.

    Args:
        stream: Stream object to shut down; a falsy value skips the shutdown.
        record: Holder of the captured frames.

    Returns:
        The result of ``_flatten_frames(record.frames)``.
    """
    if not stream:
        return _flatten_frames(record.frames)

    pending_error = None
    try:
        stream.stop()
    except Exception as exc:
        # Hold the error so close() still runs; it is re-raised below.
        pending_error = exc

    try:
        stream.close()
    except Exception as close_exc:
        if pending_error is None:
            raise
        logging.warning("stream close failed after stop failure: %s", close_exc)

    if pending_error is not None:
        raise pending_error

    return _flatten_frames(record.frames)
|
|
|
|
|
|
def _sounddevice():
|
|
try:
|
|
import sounddevice as sd # type: ignore[import-not-found]
|
|
except ModuleNotFoundError as exc:
|
|
raise RuntimeError(
|
|
"sounddevice is not installed; install dependencies with `uv sync --extra x11`"
|
|
) from exc
|
|
return sd
|
|
|
|
|
|
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
|
|
frames = list(frames)
|
|
if not frames:
|
|
return np.zeros((0,), dtype=np.float32)
|
|
data = np.concatenate(frames, axis=0)
|
|
if data.ndim > 1:
|
|
data = np.squeeze(data, axis=-1)
|
|
return np.asarray(data, dtype=np.float32).reshape(-1)
|