Use in-memory audio buffers for STT instead of temporary WAV files on disk

This commit is contained in:
Thales Maciel 2026-02-24 11:48:02 -03:00
parent 861f199dea
commit ebba452268
No known key found for this signature in database
GPG key ID: 33112E6833C34679
5 changed files with 17 additions and 50 deletions

View file

@ -103,7 +103,7 @@ class Daemon:
self.proc = proc
self.record = record
self.state = State.RECORDING
logging.info("recording started (%s)", record.wav_path)
logging.info("recording started")
if self.timer:
self.timer.cancel()
self.timer = threading.Timer(RECORD_TIMEOUT_SEC, self._timeout_stop)
@ -132,13 +132,13 @@ class Daemon:
logging.info("stopping recording (user)")
try:
stop_recording(proc, record)
audio = stop_recording(proc, record)
except Exception as exc:
logging.error("record stop failed: %s", exc)
self.set_state(State.IDLE)
return
if not Path(record.wav_path).exists():
if audio.size == 0:
logging.error("no audio captured")
self.set_state(State.IDLE)
return
@ -146,7 +146,7 @@ class Daemon:
try:
self.set_state(State.STT)
logging.info("stt started")
text = self._transcribe(record.wav_path)
text = self._transcribe(audio)
except Exception as exc:
logging.error("stt failed: %s", exc)
self.set_state(State.IDLE)
@ -199,8 +199,8 @@ class Daemon:
self.state = State.STT
threading.Thread(target=self._stop_and_process, daemon=True).start()
def _transcribe(self, wav_path: str) -> str:
segments, _info = self.model.transcribe(wav_path, language=STT_LANGUAGE, vad_filter=True)
def _transcribe(self, audio) -> str:
segments, _info = self.model.transcribe(audio, language=STT_LANGUAGE, vad_filter=True)
parts = []
for seg in segments:
text = (seg.text or "").strip()

View file

@ -1,21 +1,16 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable
import numpy as np
import sounddevice as sd # type: ignore[import-not-found]
import soundfile as sf # type: ignore[import-not-found]
@dataclass
class RecordResult:
wav_path: str
temp_dir: str
frames: list[np.ndarray] = field(default_factory=list)
samplerate: int = 16000
channels: int = 1
dtype: str = "int16"
dtype: str = "float32"
def list_input_devices() -> list[dict]:
@ -54,9 +49,7 @@ def resolve_input_device(spec: str | int | None) -> int | None:
def start_recording(input_spec: str | int | None) -> tuple[sd.InputStream, RecordResult]:
tmpdir = tempfile.mkdtemp(prefix="lel-")
wav = str(Path(tmpdir) / "mic.wav")
record = RecordResult(wav_path=wav, temp_dir=tmpdir)
record = RecordResult()
device = resolve_input_device(input_spec)
def callback(indata, _frames, _time, _status):
@ -73,20 +66,18 @@ def start_recording(input_spec: str | int | None) -> tuple[sd.InputStream, Recor
return stream, record
def stop_recording(stream: sd.InputStream, record: RecordResult) -> None:
def stop_recording(stream: sd.InputStream, record: RecordResult) -> np.ndarray:
if stream:
stream.stop()
stream.close()
_write_wav(record)
def _write_wav(record: RecordResult) -> None:
data = _flatten_frames(record.frames)
sf.write(record.wav_path, data, record.samplerate, subtype="PCM_16")
return _flatten_frames(record.frames)
def _flatten_frames(frames: Iterable[np.ndarray]) -> np.ndarray:
frames = list(frames)
if not frames:
return np.zeros((0, 1), dtype=np.int16)
return np.concatenate(frames, axis=0)
return np.zeros((0,), dtype=np.float32)
data = np.concatenate(frames, axis=0)
if data.ndim > 1:
data = np.squeeze(data, axis=-1)
return np.asarray(data, dtype=np.float32).reshape(-1)