aman/src/recorder.py

92 lines
2.5 KiB
Python

import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable

import numpy as np
import sounddevice as sd  # type: ignore[import-not-found]
import soundfile as sf  # type: ignore[import-not-found]
@dataclass
class RecordResult:
    """Capture state and output location for one recording.

    Holds the destination WAV path, the directory associated with it, the
    audio blocks collected so far, and the stream parameters (sample rate,
    channel count, sample dtype) used for the capture.
    """

    # Destination path of the WAV file.
    wav_path: str
    # Directory associated with this recording.
    temp_dir: str
    # Collected audio blocks; every instance starts with its own empty list.
    frames: list[np.ndarray] = field(default_factory=list)
    # Stream parameters, with their defaults.
    samplerate: int = 16000
    channels: int = 1
    dtype: str = "int16"
def list_input_devices() -> list[dict]:
    """Return the audio devices that offer at least one input channel.

    Returns:
        One ``{"index": ..., "name": ...}`` dict for each entry of
        ``sd.query_devices()`` whose ``max_input_channels`` is positive.
        ``index`` is the entry's position in that listing and ``name``
        falls back to ``""`` when the entry has no name.
    """
    return [
        {"index": position, "name": entry.get("name", "")}
        for position, entry in enumerate(sd.query_devices())
        if entry.get("max_input_channels", 0) > 0
    ]
def default_input_device() -> int | None:
    """Return the index of the default input device, if one is known.

    ``sd.default.device`` is an input/output pair. sounddevice exposes it as
    an indexable pair object rather than a plain ``tuple`` or ``list``, so
    the input slot is read by indexing instead of by checking the container
    type (a type check never matches the pair object).

    Returns:
        The default input device index, or ``None`` when no usable integer
        index is available (nothing set, no device present, or the default
        was configured as a device name rather than an index).
    """
    default = sd.default.device
    if isinstance(default, int):
        candidate = default
    else:
        try:
            # Slot 0 is the input half of the (input, output) pair.
            candidate = default[0]
        except (TypeError, IndexError, KeyError):
            return None
    # PortAudio reports "no device" as -1, and a configured device name
    # (str) is not an index; neither satisfies the ``int | None`` contract.
    if isinstance(candidate, int) and candidate >= 0:
        return candidate
    return None
def resolve_input_device(spec: str | int | None) -> int | None:
if spec is None:
return None
if isinstance(spec, int):
return spec
text = str(spec).strip()
if not text:
return None
if text.isdigit():
return int(text)
lowered = text.lower()
for device in list_input_devices():
name = (device.get("name") or "").lower()
if lowered in name:
return int(device["index"])
return None
def start_recording(input_spec: str | int | None) -> tuple[sd.InputStream, RecordResult]:
    """Open and start an input stream that collects audio into memory.

    A fresh temporary directory (prefix ``lel-``) is created and the
    returned record points at ``mic.wav`` inside it. Each audio block
    delivered to the stream callback is copied and appended to
    ``record.frames``.

    Args:
        input_spec: Device spec handed to ``resolve_input_device``; the
            resolved value (possibly ``None``) is passed to the stream as
            ``device``.

    Returns:
        The started stream and the record that accumulates its frames.

    Raises:
        Whatever device resolution, stream construction or ``start()``
        raises. In that case the temporary directory is removed and a
        constructed stream is closed before the error propagates.
    """
    tmpdir = tempfile.mkdtemp(prefix="lel-")
    stream = None
    try:
        wav = str(Path(tmpdir) / "mic.wav")
        record = RecordResult(wav_path=wav, temp_dir=tmpdir)
        device = resolve_input_device(input_spec)

        def callback(indata, _frames, _time, _status):
            # Copy the block: the original appends a copy rather than the
            # buffer handed to the callback.
            record.frames.append(indata.copy())

        stream = sd.InputStream(
            samplerate=record.samplerate,
            channels=record.channels,
            dtype=record.dtype,
            device=device,
            callback=callback,
        )
        stream.start()
    except BaseException:
        # Nothing is handed back to the caller on failure, so release what
        # was acquired here instead of leaking it.
        try:
            if stream is not None:
                stream.close()
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    return stream, record
def stop_recording(stream: sd.InputStream, record: RecordResult) -> None:
    """Stop and close the capture stream, then write the audio to disk.

    Args:
        stream: The stream to shut down; a falsy value (e.g. ``None``)
            skips the shutdown and only writes the file.
        record: The record whose frames are written to its ``wav_path``.

    Raises:
        Whatever ``stream.stop()``, ``stream.close()`` or the WAV write
        raises. The stream is closed even when ``stop()`` fails; the WAV
        file is not written in that case.
    """
    if stream:
        try:
            stream.stop()
        finally:
            # Always release the stream, even if stop() raised.
            stream.close()
    _write_wav(record)
def _write_wav(record: RecordResult) -> None:
    """Write the record's collected frames to ``record.wav_path``.

    The frames are joined into one array and written at
    ``record.samplerate`` with the ``PCM_16`` subtype.
    """
    samples = _flatten_frames(record.frames)
    sf.write(
        record.wav_path,
        samples,
        record.samplerate,
        subtype="PCM_16",
    )
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
frames = list(frames)
if not frames:
return np.zeros((0, 1), dtype=np.int16)
return np.concatenate(frames, axis=0)