From 5d8f3e2d905339b2d36ea4c73085daaf213fc548 Mon Sep 17 00:00:00 2001 From: FlippFuzz <41221030+FlippFuzz@users.noreply.github.com> Date: Tue, 9 May 2023 18:47:02 +0800 Subject: [PATCH] Implement VadOptions (#198) * Implement VadOptions * Fix line too long ./faster_whisper/transcribe.py:226:101: E501 line too long (111 > 100 characters) * Reformatted files with black * black .\faster_whisper\vad.py * black .\faster_whisper\transcribe.py * Fix import order with isort * isort .\faster_whisper\vad.py * isort .\faster_whisper\transcribe.py * Made recommended changes Recommended in https://github.com/guillaumekln/faster-whisper/pull/198 * Fix typing of vad_options argument --------- Co-authored-by: Guillaume Klein --- faster_whisper/transcribe.py | 16 +++++++---- faster_whisper/vad.py | 53 ++++++++++++++++++++++++------------ tests/test_transcribe.py | 5 +++- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 80aade4..06154f3 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer from faster_whisper.utils import download_model, format_timestamp, get_logger from faster_whisper.vad import ( SpeechTimestampsMap, + VadOptions, collect_chunks, get_speech_timestamps, ) @@ -67,6 +68,7 @@ class TranscriptionInfo(NamedTuple): language_probability: float duration: float transcription_options: TranscriptionOptions + vad_options: VadOptions class WhisperModel: @@ -177,7 +179,7 @@ class WhisperModel: prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", vad_filter: bool = False, - vad_parameters: Optional[dict] = None, + vad_parameters: Optional[Union[dict, VadOptions]] = None, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: """Transcribes an input file. @@ -221,8 +223,8 @@ class WhisperModel: vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio without speech. This step is using the Silero VAD model https://github.com/snakers4/silero-vad. - vad_parameters: Dictionary of Silero VAD parameters (see available parameters and - default values in the function `get_speech_timestamps`). + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). Returns: A tuple with: @@ -242,8 +244,11 @@ class WhisperModel: ) if vad_filter: - vad_parameters = {} if vad_parameters is None else vad_parameters - speech_chunks = get_speech_timestamps(audio, **vad_parameters) + if vad_parameters is None: + vad_parameters = VadOptions() + elif isinstance(vad_parameters, dict): + vad_parameters = VadOptions(**vad_parameters) + speech_chunks = get_speech_timestamps(audio, vad_parameters) audio = collect_chunks(audio, speech_chunks) self.logger.info( @@ -330,6 +335,7 @@ class WhisperModel: language_probability=language_probability, duration=duration, transcription_options=options, + vad_options=vad_parameters, ) return segments, info diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index cf14d5c..cf3b626 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -3,47 +3,64 @@ import functools import os import warnings -from typing import List, Optional +from typing import List, NamedTuple, Optional import numpy as np from faster_whisper.utils import get_assets_path + # The code below is adapted from https://github.com/snakers4/silero-vad. +class VadOptions(NamedTuple): + """VAD options. - -def get_speech_timestamps( - audio: np.ndarray, - *, - threshold: float = 0.5, - min_speech_duration_ms: int = 250, - max_speech_duration_s: float = float("inf"), - min_silence_duration_ms: int = 2000, - window_size_samples: int = 1024, - speech_pad_ms: int = 400, -) -> List[dict]: - """This method is used for splitting long audios into speech chunks using silero VAD. - - Args: - audio: One dimensional float array. + Attributes: threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be + lasts more than 100s (if any), to prevent aggressive cutting. Otherwise, they will be split aggressively just before max_speech_duration_s. min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating it window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. - Values other than these may affect model perfomance!! + Values other than these may affect model performance!! speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +def get_speech_timestamps( + audio: np.ndarray, vad_options: Optional[VadOptions] = None +) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. Returns: List of dicts containing begin and end samples of each speech chunk. """ + if vad_options is None: + vad_options = VadOptions() + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + if window_size_samples not in [512, 1024, 1536]: warnings.warn( "Unusual window_size_samples! Supported window_size_samples:\n" diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 8bebd2a..f1c9572 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -29,7 +29,7 @@ def test_transcribe(jfk_path): def test_vad(jfk_path): model = WhisperModel("tiny") - segments, _ = model.transcribe( + segments, info = model.transcribe( jfk_path, vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200), @@ -47,6 +47,9 @@ def test_vad(jfk_path): assert 0 < segment.start < 1 assert 10 < segment.end < 11 + assert info.vad_options.min_silence_duration_ms == 500 + assert info.vad_options.speech_pad_ms == 200 + def test_stereo_diarization(data_dir): model = WhisperModel("tiny")