diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..e2fff83 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include faster_whisper/assets/silero_vad.onnx diff --git a/README.md b/README.md index 87c434d..a2adf26 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,22 @@ for segment in segments: print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word)) ``` +#### VAD filter + +The library integrates the [Silero VAD](https://github.com/snakers4/silero-vad) model to filter out parts of the audio without speech: + +```python +segments, _ = model.transcribe("audio.mp3", vad_filter=True) +``` + +The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the function [`get_speech_timestamps`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`: + +```python +segments, _ = model.transcribe("audio.mp3", vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500)) +``` + +#### Going further + See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation. ### CLI diff --git a/faster_whisper/assets/silero_vad.onnx b/faster_whisper/assets/silero_vad.onnx new file mode 100644 index 0000000..5c21912 Binary files /dev/null and b/faster_whisper/assets/silero_vad.onnx differ diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 21622db..a31cf1e 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -12,6 +12,11 @@ from faster_whisper.audio import decode_audio from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.tokenizer import Tokenizer from faster_whisper.utils import download_model +from faster_whisper.vad import ( + SpeechTimestampsMap, + collect_chunks, + get_speech_timestamps, +) class Word(NamedTuple): @@ -152,6 +157,8 @@ class WhisperModel: word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + vad_filter: bool = False, + vad_parameters: Optional[dict] = None, ) -> Tuple[Iterable[Segment], AudioInfo]: """Transcribes an input file. @@ -192,6 +199,11 @@ class WhisperModel: with the next word append_punctuations: If word_timestamps is True, merge these punctuation symbols with the previous word + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters (see available parameters and + default values in the function `get_speech_timestamps`). Returns: A tuple with: @@ -205,6 +217,14 @@ class WhisperModel: ) duration = audio.shape[0] / self.feature_extractor.sampling_rate + + if vad_filter: + vad_parameters = {} if vad_parameters is None else vad_parameters + speech_chunks = get_speech_timestamps(audio, **vad_parameters) + audio = collect_chunks(audio, speech_chunks) + else: + speech_chunks = None + features = self.feature_extractor(audio) encoder_output = None @@ -254,6 +274,11 @@ class WhisperModel: segments = self.generate_segments(features, tokenizer, options, encoder_output) + if speech_chunks: + segments = restore_speech_timestamps( + segments, speech_chunks, self.feature_extractor.sampling_rate + ) + audio_info = AudioInfo( language=language, language_probability=language_probability, @@ -678,6 +703,36 @@ class WhisperModel: ] +def restore_speech_timestamps( + segments: Iterable[Segment], + speech_chunks: List[dict], + sampling_rate: int, +) -> Iterable[Segment]: + ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) + + for segment in segments: + if segment.words: + words = [] + for word in segment.words: + # Ensure the word start and end times are resolved to the same chunk. + chunk_index = ts_map.get_chunk_index(word.start) + word = word._replace( + start=ts_map.get_original_time(word.start, chunk_index), + end=ts_map.get_original_time(word.end, chunk_index), + ) + words.append(word) + else: + words = segment.words + + segment = segment._replace( + start=ts_map.get_original_time(segment.start), + end=ts_map.get_original_time(segment.end), + words=words, + ) + + yield segment + + def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: segment = np.ascontiguousarray(segment) segment = ctranslate2.StorageView.from_array(segment) diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py index 52cf03c..71ec9d5 100644 --- a/faster_whisper/utils.py +++ b/faster_whisper/utils.py @@ -1,3 +1,5 @@ +import os + from typing import Optional import huggingface_hub @@ -18,6 +20,11 @@ _MODELS = ( ) +def get_assets_path(): + """Returns the path to the assets directory.""" + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") + + def download_model(size: str, output_dir: Optional[str] = None): """Downloads a CTranslate2 Whisper model from the Hugging Face Hub. diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py new file mode 100644 index 0000000..5131f2e --- /dev/null +++ b/faster_whisper/vad.py @@ -0,0 +1,268 @@ +import bisect +import functools +import os +import warnings + +from typing import List, Optional + +import numpy as np + +from faster_whisper.utils import get_assets_path + +# The code below is adapted from https://github.com/snakers4/silero-vad. + + +def get_speech_timestamps( + audio: np.ndarray, + *, + threshold: float = 0.5, + min_speech_duration_ms: int = 250, + max_speech_duration_s: float = float("inf"), + min_silence_duration_ms: int = 2000, + window_size_samples: int = 1024, + speech_pad_ms: int = 200, +) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms + before separating it + window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. + Values other than these may affect model perfomance!! + speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + if window_size_samples not in [512, 1024, 1536]: + warnings.warn( + "Unusual window_size_samples! Supported window_size_samples:\n" + " - [512, 1024, 1536] for 16000 sampling_rate" + ) + + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + model = get_vad_model() + state = model.get_initial_state(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + chunk = audio[current_start_sample : current_start_sample + window_size_samples] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state = model(chunk, state, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + +def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) + + +class SpeechTimestampsMap: + """Helper class to restore original speech timestamps.""" + + def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): + self.sampling_rate = sampling_rate + self.time_precision = time_precision + self.chunk_end_sample = [] + self.total_silence_before = [] + + previous_end = 0 + silent_samples = 0 + + for chunk in chunks: + silent_samples += chunk["start"] - previous_end + previous_end = chunk["end"] + + self.chunk_end_sample.append(chunk["end"] - silent_samples) + self.total_silence_before.append(silent_samples / sampling_rate) + + def get_original_time( + self, + time: float, + chunk_index: Optional[int] = None, + ) -> float: + if chunk_index is None: + chunk_index = self.get_chunk_index(time) + + total_silence_before = self.total_silence_before[chunk_index] + return round(total_silence_before + time, self.time_precision) + + def get_chunk_index(self, time: float) -> int: + sample = int(time * self.sampling_rate) + return bisect.bisect(self.chunk_end_sample, sample) + + +@functools.lru_cache +def get_vad_model(): + """Returns the VAD model instance.""" + path = os.path.join(get_assets_path(), "silero_vad.onnx") + return SileroVADModel(path) + + +class SileroVADModel: + def __init__(self, path): + try: + import onnxruntime + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the onnxruntime package" + ) from e + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + opts.log_severity_level = 4 + + self.session = onnxruntime.InferenceSession( + path, + providers=["CPUExecutionProvider"], + sess_options=opts, + ) + + def get_initial_state(self, batch_size: int): + h = np.zeros((2, batch_size, 64), dtype=np.float32) + c = np.zeros((2, batch_size, 64), dtype=np.float32) + return h, c + + def __call__(self, x, state, sr: int): + if len(x.shape) == 1: + x = np.expand_dims(x, 0) + if len(x.shape) > 2: + raise ValueError( + f"Too many dimensions for input audio chunk {len(x.shape)}" + ) + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + h, c = state + + ort_inputs = { + "input": x, + "h": h, + "c": c, + "sr": np.array(sr, dtype="int64"), + } + + out, h, c = self.session.run(None, ort_inputs) + state = (h, c) + + return out, state diff --git a/requirements.txt b/requirements.txt index a8eb983..73c3b6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ av==10.* ctranslate2>=3.10,<4 huggingface_hub>=0.13 tokenizers==0.13.* +onnxruntime==1.14.* ; python_version < "3.11" diff --git a/setup.py b/setup.py index ec01689..e013715 100644 --- a/setup.py +++ b/setup.py @@ -56,4 +56,5 @@ setup( ], }, packages=find_packages(), + include_package_data=True, ) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 10e39db..5406535 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -27,6 +27,27 @@ def test_transcribe(jfk_path): assert segment.end == segment.words[-1].end +def test_vad(jfk_path): + model = WhisperModel("tiny") + segments, _ = model.transcribe( + jfk_path, + vad_filter=True, + vad_parameters=dict(min_silence_duration_ms=500), + ) + segments = list(segments) + + assert len(segments) == 1 + segment = segments[0] + + assert segment.text == ( + " And so my fellow Americans ask not what your country can do for you, " + "ask what you can do for your country." + ) + + assert 0 < segment.start < 1 + assert 10 < segment.end < 11 + + def test_stereo_diarization(data_dir): model = WhisperModel("tiny")