Support VAD filter (#95)
* Support VAD filter
* Generalize function collect_samples
* Define AudioSegment class
* Only pass prompt and prefix to the first chunk
* Add dict argument vad_parameters
* Fix isort format
* Rename method
* Update README
* Add shortcut when the chunk offset is 0
* Reword readme
* Fix end property
* Concatenate the speech chunks
* Cleanup diff
* Increase default speech pad
* Update README
* Increase default speech pad
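For reference, a minimal usage sketch of the options added by this commit (the model directory, audio file, and VAD parameter value below are placeholders; `vad_filter` and `vad_parameters` are the `transcribe` arguments introduced in the diff):

    from faster_whisper import WhisperModel

    model = WhisperModel("path/to/ctranslate2/model")  # placeholder model directory

    # vad_parameters is forwarded to get_speech_timestamps as keyword arguments.
    segments, audio_info = model.transcribe(
        "audio.wav",  # placeholder input file
        vad_filter=True,
        vad_parameters={"min_silence_duration_ms": 500},  # example Silero VAD option
    )

    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))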
@@ -12,6 +12,11 @@ from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
 from faster_whisper.utils import download_model
+from faster_whisper.vad import (
+    SpeechTimestampsMap,
+    collect_chunks,
+    get_speech_timestamps,
+)
 
 
 class Word(NamedTuple):
@@ -152,6 +157,8 @@ class WhisperModel:
         word_timestamps: bool = False,
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
+        vad_filter: bool = False,
+        vad_parameters: Optional[dict] = None,
     ) -> Tuple[Iterable[Segment], AudioInfo]:
         """Transcribes an input file.
 
@@ -192,6 +199,11 @@ class WhisperModel:
             with the next word
           append_punctuations: If word_timestamps is True, merge these punctuation symbols
             with the previous word
+          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
+            without speech. This step is using the Silero VAD model
+            https://github.com/snakers4/silero-vad.
+          vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
+            default values in the function `get_speech_timestamps`).
 
         Returns:
           A tuple with:
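As the new docstring says, `vad_parameters` maps directly onto the keyword arguments of `get_speech_timestamps`. A hypothetical example follows; the keys shown are standard Silero VAD options (`speech_pad_ms` is the padding this commit's message mentions increasing), but consult `get_speech_timestamps` for the authoritative names and defaults:

    vad_parameters = {
        "threshold": 0.5,      # minimum speech probability to treat a frame as speech
        "speech_pad_ms": 400,  # padding added around each detected speech chunk
    }
    segments, audio_info = model.transcribe(
        "audio.wav", vad_filter=True, vad_parameters=vad_parameters
    )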
@@ -205,6 +217,14 @@ class WhisperModel:
         )
 
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
+
+        if vad_filter:
+            vad_parameters = {} if vad_parameters is None else vad_parameters
+            speech_chunks = get_speech_timestamps(audio, **vad_parameters)
+            audio = collect_chunks(audio, speech_chunks)
+        else:
+            speech_chunks = None
+
         features = self.feature_extractor(audio)
 
         encoder_output = None
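The commit message's "Concatenate the speech chunks" refers to `collect_chunks` here: the detected speech regions are spliced into one shorter waveform before feature extraction. Conceptually it behaves like this sketch (the real implementation lives in `faster_whisper/vad.py`; each chunk dict carries `start`/`end` sample indices as returned by `get_speech_timestamps`):

    import numpy as np

    def collect_chunks_sketch(audio: np.ndarray, chunks: list) -> np.ndarray:
        # Keep only the samples inside detected speech chunks, in order.
        if not chunks:
            return np.array([], dtype=audio.dtype)
        return np.concatenate([audio[c["start"]:c["end"]] for c in chunks])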
@@ -254,6 +274,11 @@ class WhisperModel:
 
         segments = self.generate_segments(features, tokenizer, options, encoder_output)
 
+        if speech_chunks:
+            segments = restore_speech_timestamps(
+                segments, speech_chunks, self.feature_extractor.sampling_rate
+            )
+
         audio_info = AudioInfo(
             language=language,
             language_probability=language_probability,
@@ -678,6 +703,36 @@ class WhisperModel:
         ]
 
 
+def restore_speech_timestamps(
+    segments: Iterable[Segment],
+    speech_chunks: List[dict],
+    sampling_rate: int,
+) -> Iterable[Segment]:
+    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
+
+    for segment in segments:
+        if segment.words:
+            words = []
+            for word in segment.words:
+                # Ensure the word start and end times are resolved to the same chunk.
+                chunk_index = ts_map.get_chunk_index(word.start)
+                word = word._replace(
+                    start=ts_map.get_original_time(word.start, chunk_index),
+                    end=ts_map.get_original_time(word.end, chunk_index),
+                )
+                words.append(word)
+        else:
+            words = segment.words
+
+        segment = segment._replace(
+            start=ts_map.get_original_time(segment.start),
+            end=ts_map.get_original_time(segment.end),
+            words=words,
+        )
+
+        yield segment
+
+
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
     segment = ctranslate2.StorageView.from_array(segment)
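To see what `restore_speech_timestamps` undoes, consider a hypothetical run at 16 kHz (chunk boundaries invented for illustration): if VAD kept samples 48000..96000 (3.0s..6.0s of the original audio) and 160000..208000 (10.0s..13.0s), the filtered audio places those chunks at 0.0s..3.0s and 3.0s..6.0s. A segment time of 4.0s in the filtered audio therefore falls 1.0s into the second chunk and should map back to 11.0s:

    from faster_whisper.vad import SpeechTimestampsMap

    # Hypothetical chunks, in sample indices, as produced by get_speech_timestamps.
    speech_chunks = [
        {"start": 48000, "end": 96000},    # 3.0s..6.0s in the original audio
        {"start": 160000, "end": 208000},  # 10.0s..13.0s in the original audio
    ]
    ts_map = SpeechTimestampsMap(speech_chunks, 16000)

    print(ts_map.get_original_time(4.0))  # expected: 11.0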