Implement VadOptions (#198)

* Implement VadOptions

* Fix line too long

./faster_whisper/transcribe.py:226:101: E501 line too long (111 > 100 characters)

* Reformatted files with black

* black .\faster_whisper\vad.py    
* black .\faster_whisper\transcribe.py

* Fix import order with isort

* isort .\faster_whisper\vad.py
* isort .\faster_whisper\transcribe.py

* Made recommended changes

Recommended in https://github.com/guillaumekln/faster-whisper/pull/198

* Fix typing of vad_options argument

---------

Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
This commit is contained in:
FlippFuzz
2023-05-09 18:47:02 +08:00
committed by GitHub
parent d889345e07
commit 5d8f3e2d90
3 changed files with 50 additions and 24 deletions

View File

@@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import ( from faster_whisper.vad import (
SpeechTimestampsMap, SpeechTimestampsMap,
VadOptions,
collect_chunks, collect_chunks,
get_speech_timestamps, get_speech_timestamps,
) )
@@ -67,6 +68,7 @@ class TranscriptionInfo(NamedTuple):
language_probability: float language_probability: float
duration: float duration: float
transcription_options: TranscriptionOptions transcription_options: TranscriptionOptions
vad_options: VadOptions
class WhisperModel: class WhisperModel:
@@ -177,7 +179,7 @@ class WhisperModel:
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
vad_filter: bool = False, vad_filter: bool = False,
vad_parameters: Optional[dict] = None, vad_parameters: Optional[Union[dict, VadOptions]] = None,
) -> Tuple[Iterable[Segment], TranscriptionInfo]: ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file. """Transcribes an input file.
@@ -221,8 +223,8 @@ class WhisperModel:
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad. https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters (see available parameters and vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
default values in the function `get_speech_timestamps`). parameters and default values in the class `VadOptions`).
Returns: Returns:
A tuple with: A tuple with:
@@ -242,8 +244,11 @@ class WhisperModel:
) )
if vad_filter: if vad_filter:
vad_parameters = {} if vad_parameters is None else vad_parameters if vad_parameters is None:
speech_chunks = get_speech_timestamps(audio, **vad_parameters) vad_parameters = VadOptions()
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks) audio = collect_chunks(audio, speech_chunks)
self.logger.info( self.logger.info(
@@ -330,6 +335,7 @@ class WhisperModel:
language_probability=language_probability, language_probability=language_probability,
duration=duration, duration=duration,
transcription_options=options, transcription_options=options,
vad_options=vad_parameters,
) )
return segments, info return segments, info

View File

@@ -3,47 +3,64 @@ import functools
import os import os
import warnings import warnings
from typing import List, Optional from typing import List, NamedTuple, Optional
import numpy as np import numpy as np
from faster_whisper.utils import get_assets_path from faster_whisper.utils import get_assets_path
# The code below is adapted from https://github.com/snakers4/silero-vad. # The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
"""VAD options.
Attributes:
def get_speech_timestamps(
audio: np.ndarray,
*,
threshold: float = 0.5,
min_speech_duration_ms: int = 250,
max_speech_duration_s: float = float("inf"),
min_silence_duration_ms: int = 2000,
window_size_samples: int = 1024,
speech_pad_ms: int = 400,
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
audio: One dimensional float array.
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that than max_speech_duration_s will be split at the timestamp of the last silence that
lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be lasts more than 100s (if any), to prevent aggressive cutting. Otherwise, they will be
split aggressively just before max_speech_duration_s. split aggressively just before max_speech_duration_s.
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
before separating it before separating it
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model perfomance!! Values other than these may affect model performance!!
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
"""
threshold: float = 0.5
min_speech_duration_ms: int = 250
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
window_size_samples: int = 1024
speech_pad_ms: int = 400
def get_speech_timestamps(
audio: np.ndarray, vad_options: Optional[VadOptions] = None
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
audio: One dimensional float array.
vad_options: Options for VAD processing.
Returns: Returns:
List of dicts containing begin and end samples of each speech chunk. List of dicts containing begin and end samples of each speech chunk.
""" """
if vad_options is None:
vad_options = VadOptions()
threshold = vad_options.threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = vad_options.window_size_samples
speech_pad_ms = vad_options.speech_pad_ms
if window_size_samples not in [512, 1024, 1536]: if window_size_samples not in [512, 1024, 1536]:
warnings.warn( warnings.warn(
"Unusual window_size_samples! Supported window_size_samples:\n" "Unusual window_size_samples! Supported window_size_samples:\n"

View File

@@ -29,7 +29,7 @@ def test_transcribe(jfk_path):
def test_vad(jfk_path): def test_vad(jfk_path):
model = WhisperModel("tiny") model = WhisperModel("tiny")
segments, _ = model.transcribe( segments, info = model.transcribe(
jfk_path, jfk_path,
vad_filter=True, vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200), vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
@@ -47,6 +47,9 @@ def test_vad(jfk_path):
assert 0 < segment.start < 1 assert 0 < segment.start < 1
assert 10 < segment.end < 11 assert 10 < segment.end < 11
assert info.vad_options.min_silence_duration_ms == 500
assert info.vad_options.speech_pad_ms == 200
def test_stereo_diarization(data_dir): def test_stereo_diarization(data_dir):
model = WhisperModel("tiny") model = WhisperModel("tiny")