Merge remote-tracking branch 'upstream/master' into prompt

This commit is contained in:
2024-09-04 17:48:06 +08:00
13 changed files with 1599 additions and 423 deletions

View File

@@ -1,3 +1,4 @@
include faster_whisper/assets/silero_vad.onnx
include requirements.txt
include requirements.conversion.txt
include faster_whisper/assets/pyannote_vad_model.bin

View File

@@ -69,7 +69,6 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
* Python 3.8 or greater
Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.
### GPU
@@ -166,6 +165,35 @@ for segment in segments:
segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```
### multi-segment language detection
To directly use the model for improved language detection, the following code snippet can be used:
```python
from faster_whisper import WhisperModel
model = WhisperModel("medium", device="cuda", compute_type="float16")
language_info = model.detect_language_multi_segment("audio.mp3")
```
### Batched faster-whisper
The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-2 Clause license and integrates its VAD model to this library. We modify this implementation and also replaced the feature extraction with a faster torch-based implementation. Batched version improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference.
The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper.
```python
from faster_whisper import WhisperModel, BatchedInferencePipeline
model = WhisperModel("medium", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)
for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```
### Faster Distil-Whisper
The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)

View File

@@ -1,5 +1,6 @@
import argparse
import json
import os
from datasets import load_dataset
from evaluate import load
@@ -26,7 +27,9 @@ dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming
# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))
def inference(batch):

View File

@@ -1,5 +1,5 @@
from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel
from faster_whisper.utils import available_models, download_model, format_timestamp
from faster_whisper.version import __version__
@@ -7,6 +7,7 @@ __all__ = [
"available_models",
"decode_audio",
"WhisperModel",
"BatchedInferencePipeline",
"download_model",
"format_timestamp",
"__version__",

Binary file not shown.

View File

@@ -1,19 +1,7 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""
import gc
import io
import itertools
from typing import BinaryIO, Union
import av
import numpy as np
import torch
import torchaudio
def decode_audio(
@@ -29,91 +17,42 @@ def decode_audio(
split_stereo: Return separate left and right channels.
Returns:
A float32 Numpy array.
A float32 Torch Tensor.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T
if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
raw_buffer = io.BytesIO()
dtype = None
with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)
for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)
# It appears that some objects related to the resampler are not freed
# unless the garbage collector is manually run.
del resampler
gc.collect()
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return left_channel, right_channel
return waveform[0], waveform[1]
return audio
def _ignore_invalid_frames(frames):
iterator = iter(frames)
while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue
def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()
for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)
if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()
if fifo.samples > 0:
yield fifo.read()
def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
return waveform.mean(0)
def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
return array

View File

@@ -1,16 +1,21 @@
import numpy as np
import torch
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
@@ -22,21 +27,22 @@ class FeatureExtractor:
sampling_rate, n_fft, n_mels=feature_size
)
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
mels = torch.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
@@ -46,125 +52,63 @@ class FeatureExtractor:
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
weights *= enorm.unsqueeze(1)
return weights
def fram_wave(self, waveform, center=True):
def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
"""
Transform a raw waveform into a list of smaller waveforms.
The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step
between the beginning of each new frame.
Centering is done by reflecting the waveform which is first centered around
`frame_idx * hop_length`.
Compute the log-Mel spectrogram of the provided audio.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = (
i + half_window
if i < waveform.shape[0] - half_window
else waveform.shape[0]
)
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame,
pad_width=(0, self.n_fft - frame_width),
mode="constant",
constant_values=0,
)
frames.append(frame)
return np.stack(frames, 0)
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
Should give the same results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
def __call__(self, waveform, padding=True, chunk_length=None):
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance.
"""
if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length
if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)
waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)
if padding:
waveform = np.pad(waveform, [(0, self.n_samples)])
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
window = np.hanning(self.n_fft + 1)[:-1]
window = torch.hann_window(self.n_fft).to(waveform.device)
frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
filters = self.mel_filters
mel_spec = filters @ magnitudes
mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec

File diff suppressed because it is too large Load Diff

View File

@@ -2,9 +2,17 @@ import bisect
import functools
import os
from typing import List, NamedTuple, Optional
from abc import ABC
from collections.abc import Callable
from typing import List, NamedTuple, Optional, Union
import numpy as np
import torch
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from faster_whisper.utils import get_assets_path
@@ -35,7 +43,7 @@ class VadOptions(NamedTuple):
def get_speech_timestamps(
audio: np.ndarray,
audio: torch.Tensor,
vad_options: Optional[VadOptions] = None,
**kwargs,
) -> List[dict]:
@@ -176,12 +184,12 @@ def get_speech_timestamps(
return speeches
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
return torch.tensor([], dtype=torch.float32)
return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
class SpeechTimestampsMap:
@@ -276,3 +284,313 @@ class SileroVADModel:
context = x[..., -64:]
return out, state, context
# BSD 2-Clause License
# Copyright (c) 2024, Max Bain
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# The code below is copied from whisper-x (https://github.com/m-bain/whisperX)
# and adapted for faster_whisper.
class SegmentX:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker
class VoiceActivitySegmentation(VoiceActivityDetection, ABC):
"""Pipeline wrapper class for Voice Activity Segmentation based on VAD scores."""
def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
device: Optional[Union[str, torch.device]] = None,
fscore: bool = False,
use_auth_token: Optional[str] = None,
**inference_kwargs,
):
"""Initialize the pipeline with the model name and the optional device.
Args:
dict parameters of VoiceActivityDetection class from pyannote:
segmentation (PipelineModel): Loaded model name.
device (torch.device or None): Device to perform the segmentation.
fscore (bool): Flag indicating whether to compute F-score during inference.
use_auth_token (str or None): Optional authentication token for model access.
inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline.
"""
super().__init__(
segmentation=segmentation,
device=device,
fscore=fscore,
use_auth_token=use_auth_token,
**inference_kwargs,
)
def apply(
self, file: AudioFile, hook: Optional[Callable] = None
) -> SlidingWindowFeature:
"""Apply voice activity detection on the audio file.
Args:
file (AudioFile): Processed file.
hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file)
Returns:
segmentations (SlidingWindowFeature): Voice activity segmentation.
"""
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)
# apply segmentation model if needed
# output shape is (num_chunks, num_frames, 1)
if self.training:
if self.CACHED_SEGMENTATION in file:
segmentations = file[self.CACHED_SEGMENTATION]
else:
segmentations = self._segmentation(file)
file[self.CACHED_SEGMENTATION] = segmentations
else:
segmentations: SlidingWindowFeature = self._segmentation(file)
return segmentations
class BinarizeVadScores:
"""Binarize detection scores using hysteresis thresholding.
Reference:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747
"""
def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float("inf"),
):
"""Initializes the parameters for Binarizing the VAD scores.
Args:
onset (float, optional):
Onset threshold. Defaults to 0.5.
offset (float, optional):
Offset threshold. Defaults to `onset`.
min_duration_on (float, optional):
Remove active regions shorter than that many seconds. Defaults to 0s.
min_duration_off (float, optional):
Fill inactive regions shorter than that many seconds. Defaults to 0s.
pad_onset (float, optional):
Extend active regions by moving their start time by that many seconds.
Defaults to 0s.
pad_offset (float, optional):
Extend active regions by moving their end time by that many seconds.
Defaults to 0s.
max_duration (float):
The maximum length of an active segment.
"""
super().__init__()
self.onset = onset
self.offset = offset or onset
self.pad_onset = pad_onset
self.pad_offset = pad_offset
self.min_duration_on = min_duration_on
self.min_duration_off = min_duration_off
self.max_duration = max_duration
def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation:
"""Extract active regions from VAD scores.
Args:
scores (SlidingWindowFeature): Detection scores.
Returns:
active (Annotation): Active regions.
"""
num_frames, num_classes = scores.data.shape
frames = scores.sliding_window
timestamps = [frames[i].middle for i in range(num_frames)]
# annotation meant to store 'active' regions
active = Annotation()
for k, k_scores in enumerate(scores.data.T):
label = k if scores.labels is None else scores.labels[k]
# initial state
start = timestamps[0]
is_active = k_scores[0] > self.onset
curr_scores = [k_scores[0]]
curr_timestamps = [start]
t = start
# optionally add `strict=False` for python 3.10 or later
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(
curr_scores[search_after:]
)
min_score_t = curr_timestamps[min_score_div_idx]
region = Segment(
start - self.pad_onset, min_score_t + self.pad_offset
)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx + 1 :]
curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
is_active = False
curr_scores = []
curr_timestamps = []
curr_scores.append(y)
curr_timestamps.append(t)
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True
# if active at the end, add final region
if is_active:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
return active
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
"""Binarize detection scores.
Args:
scores (SlidingWindowFeature): Detection scores.
Returns:
active (Annotation): Binarized scores.
"""
active = self.__get_active_regions(scores)
# because of padding, some active regions might be overlapping: merge them.
# also: fill same speaker gaps shorter than min_duration_off
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
if self.max_duration < float("inf"):
raise NotImplementedError("This would break current max_duration param")
active = active.support(collar=self.min_duration_off)
# remove tracks shorter than min_duration_on
if self.min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < self.min_duration_on:
del active[segment, track]
return active
def merge_chunks(
segments,
chunk_length,
onset: float = 0.5,
offset: Optional[float] = None,
edge_padding: float = 0.1,
):
"""
Merge operation described in whisper-x paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
assert chunk_length > 0
binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(
SegmentX(
max(0.0, speech_turn.start - edge_padding),
speech_turn.end + edge_padding,
"UNKNOWN",
)
) # 100ms edge padding to account for edge errors
if len(segments_list) == 0:
print("No active speech found in audio")
return []
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
for idx, seg in enumerate(segments_list):
# if any segment start timing is less than previous segment end timing,
# reset the edge padding. Similarly for end timing.
if idx > 0:
if seg.start < segments_list[idx - 1].end:
seg.start += edge_padding
if idx < len(segments_list) - 1:
if seg.end > segments_list[idx + 1].start:
seg.end -= edge_padding
if seg.end - curr_start > chunk_length and curr_end - curr_start > 0:
merged_segments.append(
{
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
}
)
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append(
{
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
}
)
return merged_segments

View File

@@ -1,5 +1,8 @@
av>=11.0,<13
ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
pyannote-audio>=3.1.1
torch>=2.1.1
torchaudio>=2.1.2
tqdm

View File

@@ -11,3 +11,8 @@ def data_dir():
@pytest.fixture
def jfk_path(data_dir):
return os.path.join(data_dir, "jfk.flac")
@pytest.fixture
def physcisworks_path(data_dir):
return os.path.join(data_dir, "physicsworks.wav")

BIN
tests/data/physicsworks.wav Normal file

Binary file not shown.

View File

@@ -1,6 +1,6 @@
import os
from faster_whisper import WhisperModel, decode_audio
from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens
@@ -39,6 +39,50 @@ def test_transcribe(jfk_path):
assert segment.text == "".join(word.word for word in segment.words)
assert segment.start == segment.words[0].start
assert segment.end == segment.words[-1].end
batched_model = BatchedInferencePipeline(model=model, use_vad_model=False)
result, info = batched_model.transcribe(jfk_path, word_timestamps=True)
assert info.language == "en"
assert info.language_probability > 0.7
segments = []
for segment in result:
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
assert len(segments) == 1
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
def test_batched_transcribe(physcisworks_path):
model = WhisperModel("tiny")
batched_model = BatchedInferencePipeline(model=model)
result, info = batched_model.transcribe(physcisworks_path, batch_size=16)
assert info.language == "en"
assert info.language_probability > 0.7
segments = []
for segment in result:
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
# number of near 30 sec segments
assert len(segments) == 8
result, info = batched_model.transcribe(
physcisworks_path,
batch_size=16,
without_timestamps=False,
word_timestamps=True,
)
segments = []
for segment in result:
assert segment.words is not None
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
assert len(segments) > 8
def test_prefix_with_timestamps(jfk_path):
@@ -101,6 +145,13 @@ def test_stereo_diarization(data_dir):
assert transcription == "The horizon seems extremely distant."
def test_multisegment_lang_id(physcisworks_path):
model = WhisperModel("tiny")
language_info = model.detect_language_multi_segment(physcisworks_path)
assert language_info["language_code"] == "en"
assert language_info["language_confidence"] > 0.8
def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en")