diff --git a/faster_whisper/assets/silero_vad.onnx b/faster_whisper/assets/silero_vad.onnx
index 5c21912..d0ccd9d 100644
Binary files a/faster_whisper/assets/silero_vad.onnx and b/faster_whisper/assets/silero_vad.onnx differ
diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py
index 487dfa0..99dfb40 100644
--- a/faster_whisper/vad.py
+++ b/faster_whisper/vad.py
@@ -1,7 +1,6 @@
 import bisect
 import functools
 import os
-import warnings
 
 from typing import List, NamedTuple, Optional
 
@@ -25,9 +24,6 @@ class VadOptions(NamedTuple):
         split aggressively just before max_speech_duration_s.
       min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
         before separating it
-      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
-        Values other than these may affect model performance!!
       speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
     """
 
@@ -35,7 +31,6 @@ class VadOptions(NamedTuple):
     min_speech_duration_ms: int = 250
     max_speech_duration_s: float = float("inf")
     min_silence_duration_ms: int = 2000
-    window_size_samples: int = 1024
     speech_pad_ms: int = 400
 
 
@@ -61,15 +56,8 @@ def get_speech_timestamps(
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
     min_silence_duration_ms = vad_options.min_silence_duration_ms
-    window_size_samples = vad_options.window_size_samples
+    window_size_samples = 512
     speech_pad_ms = vad_options.speech_pad_ms
-
-    if window_size_samples not in [512, 1024, 1536]:
-        warnings.warn(
-            "Unusual window_size_samples! Supported window_size_samples:\n"
-            " - [512, 1024, 1536] for 16000 sampling_rate"
-        )
-
     sampling_rate = 16000
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -84,14 +72,14 @@ def get_speech_timestamps(
     audio_length_samples = len(audio)
 
     model = get_vad_model()
-    state = model.get_initial_state(batch_size=1)
+    state, context = model.get_initial_states(batch_size=1)
 
     speech_probs = []
     for current_start_sample in range(0, audio_length_samples, window_size_samples):
         chunk = audio[current_start_sample : current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob, state = model(chunk, state, sampling_rate)
+        speech_prob, state, context = model(chunk, state, context, sampling_rate)
         speech_probs.append(speech_prob)
 
     triggered = False
@@ -261,12 +249,12 @@ class SileroVADModel:
             sess_options=opts,
         )
 
-    def get_initial_state(self, batch_size: int):
-        h = np.zeros((2, batch_size, 64), dtype=np.float32)
-        c = np.zeros((2, batch_size, 64), dtype=np.float32)
-        return h, c
+    def get_initial_states(self, batch_size: int):
+        state = np.zeros((2, batch_size, 128), dtype=np.float32)
+        context = np.zeros((batch_size, 64), dtype=np.float32)
+        return state, context
 
-    def __call__(self, x, state, sr: int):
+    def __call__(self, x, state, context, sr: int):
         if len(x.shape) == 1:
             x = np.expand_dims(x, 0)
         if len(x.shape) > 2:
@@ -276,16 +264,15 @@ class SileroVADModel:
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")
 
-        h, c = state
+        x = np.concatenate([context, x], axis=1)
 
         ort_inputs = {
             "input": x,
-            "h": h,
-            "c": c,
+            "state": state,
             "sr": np.array(sr, dtype="int64"),
        }
 
-        out, h, c = self.session.run(None, ort_inputs)
-        state = (h, c)
+        out, state = self.session.run(None, ort_inputs)
+        context = x[..., -64:]
 
-        return out, state
+        return out, state, context
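For reference, below is a minimal standalone sketch of how the updated interface is driven: the new model takes a single `state` tensor plus a 64-sample `context` carried over from the previous 512-sample window. The ONNX file path and the silent dummy audio are placeholder assumptions; faster-whisper itself loads the bundled model through `get_vad_model()` and runs this loop inside `get_speech_timestamps()`.

```python
# Minimal sketch of the updated inference loop, written against the
# contract shown in this diff. The ONNX path and the dummy audio are
# placeholder assumptions, not faster-whisper's actual loading code.
import numpy as np
import onnxruntime

SAMPLING_RATE = 16000
WINDOW_SIZE_SAMPLES = 512  # the new model expects 512-sample windows at 16 kHz
CONTEXT_SIZE = 64          # last 64 samples of the previous window are prepended

session = onnxruntime.InferenceSession(
    "silero_vad.onnx", providers=["CPUExecutionProvider"]  # assumed local path
)

# Mirrors SileroVADModel.get_initial_states(batch_size=1)
state = np.zeros((2, 1, 128), dtype=np.float32)
context = np.zeros((1, CONTEXT_SIZE), dtype=np.float32)

audio = np.zeros(SAMPLING_RATE, dtype=np.float32)  # placeholder: 1 s of silence

speech_probs = []
for start in range(0, len(audio), WINDOW_SIZE_SAMPLES):
    chunk = audio[start : start + WINDOW_SIZE_SAMPLES]
    if len(chunk) < WINDOW_SIZE_SAMPLES:
        chunk = np.pad(chunk, (0, WINDOW_SIZE_SAMPLES - len(chunk)))
    x = np.expand_dims(chunk.astype(np.float32), 0)

    # Prepend the carried-over context, as the new __call__ does.
    x = np.concatenate([context, x], axis=1)

    out, state = session.run(
        None,
        {"input": x, "state": state, "sr": np.array(SAMPLING_RATE, dtype="int64")},
    )
    context = x[..., -CONTEXT_SIZE:]  # next window reuses the last 64 samples
    speech_probs.append(out)
```

Compared with the previous model, the two LSTM tensors `h` and `c` (each of shape `(2, batch, 64)`) are replaced by a single `(2, batch, 128)` `state` input/output, and only the 512-sample window is supported, which is why the `window_size_samples` option and its warning are removed.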