92 lines
2.5 KiB
Python
92 lines
2.5 KiB
Python
import tempfile
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import numpy as np
|
|
import sounddevice as sd # type: ignore[import-not-found]
|
|
import soundfile as sf # type: ignore[import-not-found]
|
|
|
|
|
|
@dataclass
class RecordResult:
    """State for a single microphone capture.

    Carries the destination WAV path and its scratch directory, the raw
    sample chunks gathered while recording, and the stream format that is
    used both to open the input stream and to write the WAV file.
    """

    # Destination of the finished recording.
    wav_path: str
    # Scratch directory that holds ``wav_path``.
    temp_dir: str
    # Sample chunks in arrival order; each instance gets its own list.
    frames: list[np.ndarray] = field(default_factory=list)
    # Sample rate in Hz used for the stream and the written file.
    samplerate: int = field(default=16000)
    # Number of capture channels requested from the stream.
    channels: int = field(default=1)
    # Sample dtype name requested from the stream.
    dtype: str = field(default="int16")
|
|
|
|
|
|
def list_input_devices() -> list[dict]:
    """Return the devices reported by ``sd.query_devices()`` that can record.

    Each result is a dict with an ``"index"`` key (the device's position in
    the ``sd.query_devices()`` listing) and a ``"name"`` key (the device's
    reported name, or ``""`` if absent). Devices whose
    ``max_input_channels`` is missing or zero are skipped.
    """
    return [
        {"index": position, "name": entry.get("name", "")}
        for position, entry in enumerate(sd.query_devices())
        if entry.get("max_input_channels", 0) > 0
    ]
|
|
|
|
|
|
def default_input_device() -> int | None:
    """Return sounddevice's default input device index, if one is usable.

    ``sd.default.device`` is an (input, output) pair. In sounddevice it is an
    ``_InputOutputPair`` object that supports ``[0]``/``[1]`` indexing but is
    neither a ``tuple`` nor a ``list``. The input entry is therefore read by
    subscription instead of being gated on an ``isinstance`` check, which
    never matched. A plain ``int`` (and real tuples/lists) are accepted too.

    Returns:
        The non-negative input device index. Returns ``None`` when no usable
        default exists: an empty pair, a non-integer entry (for example a
        device-name query string), or a negative index such as PortAudio's
        ``-1`` for "no device".
    """
    default = sd.default.device
    if isinstance(default, int):
        candidate = default
    else:
        try:
            candidate = default[0]
        except (TypeError, IndexError, KeyError):
            # Empty or non-indexable value: no default configured.
            return None
    if not isinstance(candidate, int) or candidate < 0:
        return None
    return candidate
|
|
|
|
|
|
def resolve_input_device(spec: str | int | None) -> int | None:
|
|
if spec is None:
|
|
return None
|
|
if isinstance(spec, int):
|
|
return spec
|
|
text = str(spec).strip()
|
|
if not text:
|
|
return None
|
|
if text.isdigit():
|
|
return int(text)
|
|
lowered = text.lower()
|
|
for device in list_input_devices():
|
|
name = (device.get("name") or "").lower()
|
|
if lowered in name:
|
|
return int(device["index"])
|
|
return None
|
|
|
|
|
|
def start_recording(input_spec: str | int | None) -> tuple[sd.InputStream, RecordResult]:
    """Create a temp directory and start capturing microphone audio into it.

    Args:
        input_spec: Device selector passed to ``resolve_input_device``.

    Returns:
        ``(stream, record)``. ``stream`` is the started input stream.
        ``record`` accumulates copies of every captured block in
        ``record.frames``; its ``wav_path`` points at ``mic.wav`` inside a
        fresh ``lel-*`` temp directory.

    Raises:
        Whatever device resolution, ``sd.InputStream`` or ``stream.start()``
        raises (for example for an invalid device). Before re-raising, any
        stream that was opened is closed and the temp directory is removed,
        so a failed start leaks nothing.
    """
    import shutil  # Only needed for cleanup on the failure path.

    tmpdir = tempfile.mkdtemp(prefix="lel-")
    record = RecordResult(wav_path=str(Path(tmpdir) / "mic.wav"), temp_dir=tmpdir)

    def callback(indata, _frames, _time, _status):
        # The indata buffer is only valid during the callback, so keep a
        # copy. Status flags (e.g. overflow) are currently ignored.
        record.frames.append(indata.copy())

    stream = None
    try:
        device = resolve_input_device(input_spec)
        stream = sd.InputStream(
            samplerate=record.samplerate,
            channels=record.channels,
            dtype=record.dtype,
            device=device,
            callback=callback,
        )
        stream.start()
    except BaseException:
        try:
            if stream is not None:
                stream.close()
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    return stream, record
|
|
|
|
|
|
def stop_recording(stream: sd.InputStream, record: RecordResult) -> None:
    """Stop and close the capture stream, then write the audio to disk.

    Args:
        stream: The stream returned by ``start_recording``. ``None`` is
            tolerated, in which case only the file is written.
        record: The capture state whose ``frames`` are written to
            ``record.wav_path``.

    Raises:
        Propagates errors from ``stream.stop()``. The stream is still closed
        in that case, but no WAV file is written.
    """
    if stream is not None:
        try:
            stream.stop()
        finally:
            # Release the device even if stop() fails.
            stream.close()
    _write_wav(record)
|
|
|
|
|
|
def _write_wav(record: RecordResult) -> None:
    """Save the collected chunks of *record* as 16-bit PCM at ``record.wav_path``.

    The chunks are joined along the sample axis first. The file uses
    ``record.samplerate`` as its sample rate.
    """
    samples = _flatten_frames(record.frames)
    sf.write(
        record.wav_path,
        samples,
        record.samplerate,
        subtype="PCM_16",
    )
|
|
|
|
|
|
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
|
|
frames = list(frames)
|
|
if not frames:
|
|
return np.zeros((0, 1), dtype=np.int16)
|
|
return np.concatenate(frames, axis=0)
|