Support VAD filter (#95)

* Support VAD filter

* Generalize function collect_samples

* Define AudioSegment class

* Only pass prompt and prefix to the first chunk

* Add dict argument vad_parameters

* Fix isort format

* Rename method

* Update README

* Add shortcut when the chunk offset is 0

* Reword readme

* Fix end property

* Concatenate the speech chunks

* Cleanup diff

* Increase default speech pad

* Update README

* Increase default speech pad

Authored by Guillaume Klein on 2023-04-03 17:22:48 +02:00, committed by GitHub
parent b4c1c57781
commit 19698c95f8
9 changed files with 370 additions and 0 deletions

faster_whisper/transcribe.py

@@ -12,6 +12,11 @@ from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model
from faster_whisper.vad import (
    SpeechTimestampsMap,
    collect_chunks,
    get_speech_timestamps,
)


class Word(NamedTuple):
@@ -152,6 +157,8 @@ class WhisperModel:
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,!?::”)]}、",
        vad_filter: bool = False,
        vad_parameters: Optional[dict] = None,
    ) -> Tuple[Iterable[Segment], AudioInfo]:
        """Transcribes an input file.
@@ -192,6 +199,11 @@ class WhisperModel:
            with the next word
          append_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the previous word
          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
            without speech. This step is using the Silero VAD model
            https://github.com/snakers4/silero-vad.
          vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
            default values in the function `get_speech_timestamps`).

        Returns:
          A tuple with:
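
The two new arguments documented above are part of the public `transcribe` API. A minimal usage sketch (the audio path and the `min_silence_duration_ms` key are illustrative assumptions, not taken from this commit; the dict is simply forwarded to `get_speech_timestamps`):

from faster_whisper import WhisperModel

model = WhisperModel("small")

# Filter out non-speech audio with the Silero VAD before transcription.
# "audio.mp3" and the min_silence_duration_ms value are placeholders; any
# parameter accepted by get_speech_timestamps can be passed in the dict.
segments, info = model.transcribe(
    "audio.mp3",
    vad_filter=True,
    vad_parameters={"min_silence_duration_ms": 500},
)

print("Detected language:", info.language)
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
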
@@ -205,6 +217,14 @@ class WhisperModel:
        )

        duration = audio.shape[0] / self.feature_extractor.sampling_rate

        if vad_filter:
            vad_parameters = {} if vad_parameters is None else vad_parameters
            speech_chunks = get_speech_timestamps(audio, **vad_parameters)
            audio = collect_chunks(audio, speech_chunks)
        else:
            speech_chunks = None

        features = self.feature_extractor(audio)

        encoder_output = None
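
For reference, the chunk collection boils down to concatenating the detected speech regions of the waveform. A minimal sketch of the idea (an illustration only, not the implementation added in faster_whisper/vad.py; the helper name is made up):

import numpy as np

def concat_speech_chunks(audio: np.ndarray, chunks: list) -> np.ndarray:
    # Each chunk is a dict with "start"/"end" sample indices produced by the
    # VAD; slicing those ranges and concatenating them yields speech-only audio.
    if not chunks:
        return np.array([], dtype=audio.dtype)
    return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
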
@@ -254,6 +274,11 @@ class WhisperModel:
        segments = self.generate_segments(features, tokenizer, options, encoder_output)

        if speech_chunks:
            segments = restore_speech_timestamps(
                segments, speech_chunks, self.feature_extractor.sampling_rate
            )

        audio_info = AudioInfo(
            language=language,
            language_probability=language_probability,
@@ -678,6 +703,36 @@ class WhisperModel:
        ]


def restore_speech_timestamps(
    segments: Iterable[Segment],
    speech_chunks: List[dict],
    sampling_rate: int,
) -> Iterable[Segment]:
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

    for segment in segments:
        if segment.words:
            words = []

            for word in segment.words:
                # Ensure the word start and end times are resolved to the same chunk.
                chunk_index = ts_map.get_chunk_index(word.start)
                word = word._replace(
                    start=ts_map.get_original_time(word.start, chunk_index),
                    end=ts_map.get_original_time(word.end, chunk_index),
                )

                words.append(word)
        else:
            words = segment.words

        segment = segment._replace(
            start=ts_map.get_original_time(segment.start),
            end=ts_map.get_original_time(segment.end),
            words=words,
        )

        yield segment


def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    segment = np.ascontiguousarray(segment)
    segment = ctranslate2.StorageView.from_array(segment)
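
Restoring timestamps works because the filter only removes audio between speech chunks: a time on the filtered timeline maps back to the original by adding the duration removed before the chunk it falls in, which is what SpeechTimestampsMap encapsulates. A standalone sketch of that mapping (illustration only; the function name is made up and this is not the library's implementation):

def to_original_time(filtered_time: float, speech_chunks: list, sampling_rate: int) -> float:
    # Walk the chunks, accumulating how much audio was removed before each one;
    # once the filtered time falls inside a chunk, add that offset back.
    filtered_end = 0.0  # end of the current chunk on the filtered timeline (seconds)
    offset = 0.0        # audio removed before the current chunk (seconds)
    previous_end = 0    # end sample of the previous chunk in the original audio
    for chunk in speech_chunks:
        offset += (chunk["start"] - previous_end) / sampling_rate
        filtered_end += (chunk["end"] - chunk["start"]) / sampling_rate
        previous_end = chunk["end"]
        if filtered_time <= filtered_end:
            break
    return filtered_time + offset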