Implement VadOptions (#198)
* Implement VadOptions * Fix line too long ./faster_whisper/transcribe.py:226:101: E501 line too long (111 > 100 characters) * Reformatted files with black * black .\faster_whisper\vad.py * black .\faster_whisper\transcribe.py * Fix import order with isort * isort .\faster_whisper\vad.py * isort .\faster_whisper\transcribe.py * Made recommended changes Recommended in https://github.com/guillaumekln/faster-whisper/pull/198 * Fix typing of vad_options argument --------- Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer
|
|||||||
from faster_whisper.utils import download_model, format_timestamp, get_logger
|
from faster_whisper.utils import download_model, format_timestamp, get_logger
|
||||||
from faster_whisper.vad import (
|
from faster_whisper.vad import (
|
||||||
SpeechTimestampsMap,
|
SpeechTimestampsMap,
|
||||||
|
VadOptions,
|
||||||
collect_chunks,
|
collect_chunks,
|
||||||
get_speech_timestamps,
|
get_speech_timestamps,
|
||||||
)
|
)
|
||||||
@@ -67,6 +68,7 @@ class TranscriptionInfo(NamedTuple):
|
|||||||
language_probability: float
|
language_probability: float
|
||||||
duration: float
|
duration: float
|
||||||
transcription_options: TranscriptionOptions
|
transcription_options: TranscriptionOptions
|
||||||
|
vad_options: VadOptions
|
||||||
|
|
||||||
|
|
||||||
class WhisperModel:
|
class WhisperModel:
|
||||||
@@ -177,7 +179,7 @@ class WhisperModel:
|
|||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
vad_filter: bool = False,
|
vad_filter: bool = False,
|
||||||
vad_parameters: Optional[dict] = None,
|
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
||||||
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
||||||
"""Transcribes an input file.
|
"""Transcribes an input file.
|
||||||
|
|
||||||
@@ -221,8 +223,8 @@ class WhisperModel:
|
|||||||
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
||||||
without speech. This step is using the Silero VAD model
|
without speech. This step is using the Silero VAD model
|
||||||
https://github.com/snakers4/silero-vad.
|
https://github.com/snakers4/silero-vad.
|
||||||
vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
|
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
||||||
default values in the function `get_speech_timestamps`).
|
parameters and default values in the class `VadOptions`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple with:
|
A tuple with:
|
||||||
@@ -242,8 +244,11 @@ class WhisperModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if vad_filter:
|
if vad_filter:
|
||||||
vad_parameters = {} if vad_parameters is None else vad_parameters
|
if vad_parameters is None:
|
||||||
speech_chunks = get_speech_timestamps(audio, **vad_parameters)
|
vad_parameters = VadOptions()
|
||||||
|
elif isinstance(vad_parameters, dict):
|
||||||
|
vad_parameters = VadOptions(**vad_parameters)
|
||||||
|
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
||||||
audio = collect_chunks(audio, speech_chunks)
|
audio = collect_chunks(audio, speech_chunks)
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
@@ -330,6 +335,7 @@ class WhisperModel:
|
|||||||
language_probability=language_probability,
|
language_probability=language_probability,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
transcription_options=options,
|
transcription_options=options,
|
||||||
|
vad_options=vad_parameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
return segments, info
|
return segments, info
|
||||||
|
|||||||
@@ -3,47 +3,64 @@ import functools
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from faster_whisper.utils import get_assets_path
|
from faster_whisper.utils import get_assets_path
|
||||||
|
|
||||||
|
|
||||||
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
||||||
|
class VadOptions(NamedTuple):
|
||||||
|
"""VAD options.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
def get_speech_timestamps(
|
|
||||||
audio: np.ndarray,
|
|
||||||
*,
|
|
||||||
threshold: float = 0.5,
|
|
||||||
min_speech_duration_ms: int = 250,
|
|
||||||
max_speech_duration_s: float = float("inf"),
|
|
||||||
min_silence_duration_ms: int = 2000,
|
|
||||||
window_size_samples: int = 1024,
|
|
||||||
speech_pad_ms: int = 400,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: One dimensional float array.
|
|
||||||
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
|
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
|
||||||
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
||||||
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||||
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
|
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
|
||||||
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
|
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
|
||||||
than max_speech_duration_s will be split at the timestamp of the last silence that
|
than max_speech_duration_s will be split at the timestamp of the last silence that
|
||||||
lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be
|
lasts more than 100s (if any), to prevent aggressive cutting. Otherwise, they will be
|
||||||
split aggressively just before max_speech_duration_s.
|
split aggressively just before max_speech_duration_s.
|
||||||
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
||||||
before separating it
|
before separating it
|
||||||
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
||||||
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
||||||
Values other than these may affect model perfomance!!
|
Values other than these may affect model performance!!
|
||||||
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
||||||
|
"""
|
||||||
|
|
||||||
|
threshold: float = 0.5
|
||||||
|
min_speech_duration_ms: int = 250
|
||||||
|
max_speech_duration_s: float = float("inf")
|
||||||
|
min_silence_duration_ms: int = 2000
|
||||||
|
window_size_samples: int = 1024
|
||||||
|
speech_pad_ms: int = 400
|
||||||
|
|
||||||
|
|
||||||
|
def get_speech_timestamps(
|
||||||
|
audio: np.ndarray, vad_options: Optional[VadOptions] = None
|
||||||
|
) -> List[dict]:
|
||||||
|
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: One dimensional float array.
|
||||||
|
vad_options: Options for VAD processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of dicts containing begin and end samples of each speech chunk.
|
List of dicts containing begin and end samples of each speech chunk.
|
||||||
"""
|
"""
|
||||||
|
if vad_options is None:
|
||||||
|
vad_options = VadOptions()
|
||||||
|
|
||||||
|
threshold = vad_options.threshold
|
||||||
|
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
||||||
|
max_speech_duration_s = vad_options.max_speech_duration_s
|
||||||
|
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
||||||
|
window_size_samples = vad_options.window_size_samples
|
||||||
|
speech_pad_ms = vad_options.speech_pad_ms
|
||||||
|
|
||||||
if window_size_samples not in [512, 1024, 1536]:
|
if window_size_samples not in [512, 1024, 1536]:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Unusual window_size_samples! Supported window_size_samples:\n"
|
"Unusual window_size_samples! Supported window_size_samples:\n"
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def test_transcribe(jfk_path):
|
|||||||
|
|
||||||
def test_vad(jfk_path):
|
def test_vad(jfk_path):
|
||||||
model = WhisperModel("tiny")
|
model = WhisperModel("tiny")
|
||||||
segments, _ = model.transcribe(
|
segments, info = model.transcribe(
|
||||||
jfk_path,
|
jfk_path,
|
||||||
vad_filter=True,
|
vad_filter=True,
|
||||||
vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
|
vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
|
||||||
@@ -47,6 +47,9 @@ def test_vad(jfk_path):
|
|||||||
assert 0 < segment.start < 1
|
assert 0 < segment.start < 1
|
||||||
assert 10 < segment.end < 11
|
assert 10 < segment.end < 11
|
||||||
|
|
||||||
|
assert info.vad_options.min_silence_duration_ms == 500
|
||||||
|
assert info.vad_options.speech_pad_ms == 200
|
||||||
|
|
||||||
|
|
||||||
def test_stereo_diarization(data_dir):
|
def test_stereo_diarization(data_dir):
|
||||||
model = WhisperModel("tiny")
|
model = WhisperModel("tiny")
|
||||||
|
|||||||
Reference in New Issue
Block a user