Support VAD filter (#95)

* Support VAD filter

* Generalize function collect_samples

* Define AudioSegment class

* Only pass prompt and prefix to the first chunk

* Add dict argument vad_parameters

* Fix isort format

* Rename method

* Update README

* Add shortcut when the chunk offset is 0

* Reword readme

* Fix end property

* Concatenate the speech chunks

* Cleanup diff

* Increase default speech pad

* Update README

* Increase default speech pad
Author: Guillaume Klein
Date: 2023-04-03 17:22:48 +02:00
Committed by: GitHub
Parent: b4c1c57781
Commit: 19698c95f8
9 changed files with 370 additions and 0 deletions

MANIFEST.in (new file, 1 line)

@@ -0,0 +1 @@
include faster_whisper/assets/silero_vad.onnx

README.md

@@ -97,6 +97,22 @@ for segment in segments:
print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
```
#### VAD filter
The library integrates the [Silero VAD](https://github.com/snakers4/silero-vad) model to filter out parts of the audio without speech:
```python
segments, _ = model.transcribe("audio.mp3", vad_filter=True)
```
The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the function [`get_speech_timestamps`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`:
```python
segments, _ = model.transcribe("audio.mp3", vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500))
```
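VAD filtering composes with the other `transcribe` options shown above, for example word-level timestamps (a sketch; the parameter values are illustrative, not recommendations):

```python
segments, _ = model.transcribe(
    "audio.mp3",
    word_timestamps=True,
    vad_filter=True,
    vad_parameters=dict(threshold=0.5, min_silence_duration_ms=500, speech_pad_ms=200),
)

for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
```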
#### Going further
See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
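For instance, a hedged sketch (options such as `beam_size` belong to the existing `transcribe` signature and are not part of this diff):

```python
segments, info = model.transcribe("audio.mp3", beam_size=5, vad_filter=True)

print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
```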
### CLI

faster_whisper/assets/silero_vad.onnx (binary file not shown)

faster_whisper/transcribe.py

@@ -12,6 +12,11 @@ from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model
from faster_whisper.vad import (
    SpeechTimestampsMap,
    collect_chunks,
    get_speech_timestamps,
)
class Word(NamedTuple):
@@ -152,6 +157,8 @@ class WhisperModel:
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,!?::”)]}、",
        vad_filter: bool = False,
        vad_parameters: Optional[dict] = None,
    ) -> Tuple[Iterable[Segment], AudioInfo]:
        """Transcribes an input file.
@@ -192,6 +199,11 @@ class WhisperModel:
            with the next word
          append_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the previous word
          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
            without speech. This step uses the Silero VAD model
            https://github.com/snakers4/silero-vad.
          vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
            default values in the function `get_speech_timestamps`).

        Returns:
          A tuple with:
@@ -205,6 +217,14 @@ class WhisperModel:
            )

        duration = audio.shape[0] / self.feature_extractor.sampling_rate

        if vad_filter:
            vad_parameters = {} if vad_parameters is None else vad_parameters
            speech_chunks = get_speech_timestamps(audio, **vad_parameters)
            audio = collect_chunks(audio, speech_chunks)
        else:
            speech_chunks = None

        features = self.feature_extractor(audio)

        encoder_output = None
@@ -254,6 +274,11 @@ class WhisperModel:
        segments = self.generate_segments(features, tokenizer, options, encoder_output)

        if speech_chunks:
            segments = restore_speech_timestamps(
                segments, speech_chunks, self.feature_extractor.sampling_rate
            )

        audio_info = AudioInfo(
            language=language,
            language_probability=language_probability,
@@ -678,6 +703,36 @@ class WhisperModel:
]
def restore_speech_timestamps(
    segments: Iterable[Segment],
    speech_chunks: List[dict],
    sampling_rate: int,
) -> Iterable[Segment]:
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

    for segment in segments:
        if segment.words:
            words = []
            for word in segment.words:
                # Ensure the word start and end times are resolved to the same chunk.
                chunk_index = ts_map.get_chunk_index(word.start)
                word = word._replace(
                    start=ts_map.get_original_time(word.start, chunk_index),
                    end=ts_map.get_original_time(word.end, chunk_index),
                )
                words.append(word)
        else:
            words = segment.words

        segment = segment._replace(
            start=ts_map.get_original_time(segment.start),
            end=ts_map.get_original_time(segment.end),
            words=words,
        )

        yield segment


def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    segment = np.ascontiguousarray(segment)
    segment = ctranslate2.StorageView.from_array(segment)
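A small worked example may help make `restore_speech_timestamps` concrete (a sketch; the chunk boundaries below are made up for illustration). `SpeechTimestampsMap` adds back the silence that was removed before the chunk containing each timestamp:

```python
from faster_whisper.vad import SpeechTimestampsMap

sampling_rate = 16000

# Hypothetical VAD output: speech at 2.0-4.0s and 6.0-7.5s of the original audio,
# so 2.0s of silence is removed before the first chunk and 4.0s before the second.
speech_chunks = [
    {"start": 2 * sampling_rate, "end": 4 * sampling_rate},
    {"start": 6 * sampling_rate, "end": int(7.5 * sampling_rate)},
]

ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

print(ts_map.get_original_time(1.5))  # 3.5: falls in the first chunk, 1.5 + 2.0
print(ts_map.get_original_time(2.7))  # 6.7: falls in the second chunk, 2.7 + 4.0
```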

faster_whisper/utils.py

@@ -1,3 +1,5 @@
import os

from typing import Optional

import huggingface_hub
@@ -18,6 +20,11 @@ _MODELS = (
)
def get_assets_path():
"""Returns the path to the assets directory."""
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
def download_model(size: str, output_dir: Optional[str] = None):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

faster_whisper/vad.py (new file, 268 lines)

@@ -0,0 +1,268 @@
import bisect
import functools
import os
import warnings

from typing import List, Optional

import numpy as np

from faster_whisper.utils import get_assets_path

# The code below is adapted from https://github.com/snakers4/silero-vad.


def get_speech_timestamps(
    audio: np.ndarray,
    *,
    threshold: float = 0.5,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 2000,
    window_size_samples: int = 1024,
    speech_pad_ms: int = 200,
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
audio: One dimensional float array.
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that
lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be
split aggressively just before max_speech_duration_s.
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
before separating it
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model perfomance!!
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
Returns:
List of dicts containing begin and end samples of each speech chunk.
"""
    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    sampling_rate = 16000
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob, state = model(chunk, state, sampling_rate)
        speech_probs.append(speech_prob)
    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )

    return speeches
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
class SpeechTimestampsMap:
"""Helper class to restore original speech timestamps."""
def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
self.sampling_rate = sampling_rate
self.time_precision = time_precision
self.chunk_end_sample = []
self.total_silence_before = []
previous_end = 0
silent_samples = 0
for chunk in chunks:
silent_samples += chunk["start"] - previous_end
previous_end = chunk["end"]
self.chunk_end_sample.append(chunk["end"] - silent_samples)
self.total_silence_before.append(silent_samples / sampling_rate)
def get_original_time(
self,
time: float,
chunk_index: Optional[int] = None,
) -> float:
if chunk_index is None:
chunk_index = self.get_chunk_index(time)
total_silence_before = self.total_silence_before[chunk_index]
return round(total_silence_before + time, self.time_precision)
def get_chunk_index(self, time: float) -> int:
sample = int(time * self.sampling_rate)
return bisect.bisect(self.chunk_end_sample, sample)
@functools.lru_cache
def get_vad_model():
"""Returns the VAD model instance."""
path = os.path.join(get_assets_path(), "silero_vad.onnx")
return SileroVADModel(path)
class SileroVADModel:
    def __init__(self, path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def get_initial_state(self, batch_size: int):
        h = np.zeros((2, batch_size, 64), dtype=np.float32)
        c = np.zeros((2, batch_size, 64), dtype=np.float32)
        return h, c

    def __call__(self, x, state, sr: int):
        if len(x.shape) == 1:
            x = np.expand_dims(x, 0)
        if len(x.shape) > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {len(x.shape)}"
            )
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        h, c = state

        ort_inputs = {
            "input": x,
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }

        out, h, c = self.session.run(None, ort_inputs)
        state = (h, c)

        return out, state
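For reference, a minimal standalone sketch of how the two new helpers fit together, assuming `decode_audio` accepts a `sampling_rate` argument (the function is only imported, not shown, in this diff):

```python
from faster_whisper.audio import decode_audio
from faster_whisper.vad import collect_chunks, get_speech_timestamps

# get_speech_timestamps expects mono float audio at 16 kHz (the sampling rate
# is fixed to 16000 internally); the decode_audio call below assumes it can
# resample to that rate, mirroring how transcribe.py prepares the audio.
audio = decode_audio("audio.mp3", sampling_rate=16000)

speech_chunks = get_speech_timestamps(audio, min_silence_duration_ms=500)
filtered_audio = collect_chunks(audio, speech_chunks)

print("kept %d speech chunk(s)" % len(speech_chunks))
print("%.1fs -> %.1fs of audio" % (len(audio) / 16000, len(filtered_audio) / 16000))
```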

requirements.txt

@@ -2,3 +2,4 @@ av==10.*
ctranslate2>=3.10,<4
huggingface_hub>=0.13
tokenizers==0.13.*
onnxruntime==1.14.* ; python_version < "3.11"

setup.py

@@ -56,4 +56,5 @@ setup(
        ],
    },
    packages=find_packages(),
    include_package_data=True,
)

tests/test_transcribe.py

@@ -27,6 +27,27 @@ def test_transcribe(jfk_path):
        assert segment.end == segment.words[-1].end


def test_vad(jfk_path):
    model = WhisperModel("tiny")
    segments, _ = model.transcribe(
        jfk_path,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
    )
    segments = list(segments)

    assert len(segments) == 1

    segment = segments[0]

    assert segment.text == (
        " And so my fellow Americans ask not what your country can do for you, "
        "ask what you can do for your country."
    )

    assert 0 < segment.start < 1
    assert 10 < segment.end < 11


def test_stereo_diarization(data_dir):
    model = WhisperModel("tiny")