aman/src/recorder.py

83 lines
2.2 KiB
Python

from dataclasses import dataclass, field
from typing import Iterable
import numpy as np
import sounddevice as sd # type: ignore[import-not-found]
@dataclass
class RecordResult:
    """Accumulator for one recording session plus the stream settings used.

    ``start_recording`` creates an instance, passes ``samplerate``,
    ``channels`` and ``dtype`` to ``sd.InputStream`` and has the stream
    callback append a copy of every incoming buffer to ``frames``.
    """

    # Captured audio buffers, in arrival order; each instance gets its own list.
    frames: list[np.ndarray] = field(default_factory=list)
    # Value handed to ``sd.InputStream(samplerate=...)``.
    samplerate: int = 16000
    # Value handed to ``sd.InputStream(channels=...)``.
    channels: int = 1
    # Value handed to ``sd.InputStream(dtype=...)``.
    dtype: str = "float32"
def list_input_devices() -> list[dict]:
    """Return the devices reported by ``sd.query_devices()`` that can record.

    A device is included when its ``max_input_channels`` entry is greater
    than zero (a missing entry counts as zero).

    Returns:
        One ``{"index": ..., "name": ...}`` dict per matching device, where
        ``index`` is the device's position in the ``sd.query_devices()``
        listing and ``name`` falls back to ``""`` when absent.
    """
    return [
        {"index": position, "name": entry.get("name", "")}
        for position, entry in enumerate(sd.query_devices())
        if entry.get("max_input_channels", 0) > 0
    ]
def default_input_device() -> int | None:
    """Return the default input device taken from ``sd.default.device``.

    ``sd.default.device`` is normally an indexable input/output pair object
    rather than a real ``tuple``/``list``, so the input slot is read by
    indexing instead of by an ``isinstance`` check. A bare ``int`` is still
    accepted and returned directly.

    Returns:
        The input-side entry of the default device setting, or ``None`` when
        the setting is not indexable, is empty, or is a negative index
        (PortAudio reports "no device" as ``-1``).
    """
    default = sd.default.device
    if isinstance(default, int):
        candidate = default
    elif isinstance(default, (str, bytes)):
        # A string is indexable too, but its first character is not a device.
        return None
    else:
        try:
            candidate = default[0]
        except (TypeError, IndexError, KeyError):
            return None
    # A negative index means no default input device is available.
    if isinstance(candidate, int) and candidate < 0:
        return None
    return candidate
def resolve_input_device(spec: str | int | None) -> int | None:
if spec is None:
return None
if isinstance(spec, int):
return spec
text = str(spec).strip()
if not text:
return None
if text.isdigit():
return int(text)
lowered = text.lower()
for device in list_input_devices():
name = (device.get("name") or "").lower()
if lowered in name:
return int(device["index"])
return None
def start_recording(
    input_spec: str | int | None,
) -> tuple[sd.InputStream, RecordResult]:
    """Open and start an input stream that captures audio into a new record.

    Args:
        input_spec: Device selector passed to ``resolve_input_device``. When
            it resolves to ``None``, ``device=None`` is handed to
            ``sd.InputStream``.

    Returns:
        The started stream and the ``RecordResult`` its callback appends to.
        Pass both to ``stop_recording`` to finish and collect the samples.

    Raises:
        Exception: Whatever ``sd.InputStream(...)`` or ``stream.start()``
            raises. If starting fails, the stream is closed before the
            exception propagates so the opened stream is not leaked.
    """
    record = RecordResult()
    device = resolve_input_device(input_spec)

    def callback(indata, _frames, _time, _status):
        # Store a copy rather than the buffer itself; presumably the backend
        # reuses ``indata`` after the callback returns -- confirm in the
        # sounddevice docs.
        record.frames.append(indata.copy())

    stream = sd.InputStream(
        samplerate=record.samplerate,
        channels=record.channels,
        dtype=record.dtype,
        device=device,
        callback=callback,
    )
    try:
        stream.start()
    except Exception:
        # The stream is already open at this point; release it on failure.
        stream.close()
        raise
    return stream, record
def stop_recording(stream: sd.InputStream, record: RecordResult) -> np.ndarray:
    """Stop and close ``stream``, then return everything captured in ``record``.

    Args:
        stream: The stream returned by ``start_recording``. A falsy value
            (e.g. ``None``) skips the stop/close step.
        record: The record whose ``frames`` were filled by the stream callback.

    Returns:
        The captured frames merged by ``_flatten_frames``.

    Raises:
        Exception: Whatever ``stream.stop()`` or ``stream.close()`` raises;
            ``close()`` is attempted even when ``stop()`` fails.
    """
    if stream:
        try:
            stream.stop()
        finally:
            # Always release the stream, even if stopping raised.
            stream.close()
    return _flatten_frames(record.frames)
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
frames = list(frames)
if not frames:
return np.zeros((0,), dtype=np.float32)
data = np.concatenate(frames, axis=0)
if data.ndim > 1:
data = np.squeeze(data, axis=-1)
return np.asarray(data, dtype=np.float32).reshape(-1)