New PR for Faster Whisper: Batching Support, Speed Boosts, and Quality Enhancements (#856)
Batching Support, Speed Boosts, and Quality Enhancements

Co-authored-by: Hargun Mujral <83234565+hargunmujral@users.noreply.github.com>
Co-authored-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
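In short, this commit rewrites faster_whisper/feature_extraction.py from numpy to PyTorch: the mel filterbank is built with torch ops (vectorized instead of a per-band loop), the hand-rolled framing and FFT helpers are replaced by torch.stft, and the extractor gains a device argument plus a to_cpu flag so features can be computed on GPU and handed back to the CPU.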
faster_whisper/feature_extraction.py

@@ -1,16 +1,21 @@
-import numpy as np
+import torch
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
 class FeatureExtractor:
     def __init__(
         self,
+        device: str = "auto",
         feature_size=80,
         sampling_rate=16000,
         hop_length=160,
         chunk_length=30,
         n_fft=400,
     ):
+        if device == "auto":
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.chunk_length = chunk_length
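Illustrative usage (not part of the diff), assuming the constructor shown above and the module path faster_whisper/feature_extraction.py: device="auto" resolves to CUDA when available and falls back to CPU, while an explicit device string is passed through unchanged.

    import torch

    from faster_whisper.feature_extraction import FeatureExtractor

    fe_auto = FeatureExtractor(device="auto")  # resolves to "cuda" or "cpu"
    fe_cpu = FeatureExtractor(device="cpu")    # explicit choice is kept as-is

    assert fe_cpu.device == "cpu"
    assert fe_auto.device == ("cuda" if torch.cuda.is_available() else "cpu")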
@@ -22,21 +27,22 @@ class FeatureExtractor:
             sampling_rate, n_fft, n_mels=feature_size
         )
 
-    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
+    @staticmethod
+    def get_mel_filters(sr, n_fft, n_mels=128):
         """
         Implementation of librosa.filters.mel in Pytorch
         """
         # Initialize the weights
         n_mels = int(n_mels)
-        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
 
         # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
+        fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
 
         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965
 
-        mels = np.linspace(min_mel, max_mel, n_mels + 2)
-        mels = np.asanyarray(mels)
+        mels = torch.linspace(min_mel, max_mel, n_mels + 2)
 
         # Fill in the linear scale
         f_min = 0.0
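The constant max_mel = 45.245640471924965 above is not arbitrary: it is the Slaney-scale mel value of 8000 Hz, the Nyquist frequency of 16 kHz audio. A minimal sketch of the conversion, mirroring the constants used in get_mel_filters (hz_to_mel_slaney is a hypothetical helper, for illustration only):

    import math

    def hz_to_mel_slaney(hz):
        # Linear below 1000 Hz: 200/3 Hz per mel.
        f_sp = 200.0 / 3
        min_log_hz = 1000.0
        min_log_mel = min_log_hz / f_sp  # 15.0 mels
        logstep = math.log(6.4) / 27.0   # step size in the log region
        if hz < min_log_hz:
            return hz / f_sp
        return min_log_mel + math.log(hz / min_log_hz) / logstep

    print(hz_to_mel_slaney(8000.0))  # 45.245640471924965, the max_mel constant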
@@ -46,125 +52,63 @@ class FeatureExtractor:
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
+        logstep = torch.log(torch.tensor(6.4)) / 27.0  # step size for log region
 
         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
+        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
 
         mel_f = freqs
 
-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
+        fdiff = torch.diff(mel_f)
+        ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
 
-        for i in range(n_mels):
-            # lower and upper slopes for all bins
-            lower = -ramps[i] / fdiff[i]
-            upper = ramps[i + 2] / fdiff[i + 1]
+        lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
+        upper = ramps[2:] / fdiff[1:].unsqueeze(1)
 
-            # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
+        # Intersect them with each other and zero, vectorized across all i
+        weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
 
         # Slaney-style mel is scaled to be approx constant energy per channel
         enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm[:, np.newaxis]
+        weights *= enorm.unsqueeze(1)
 
         return weights
 
-    def fram_wave(self, waveform, center=True):
+    def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
         """
-        Transform a raw waveform into a list of smaller waveforms.
-        The window length defines how much of the signal is
-        contain in each frame (smalle waveform), while the hope length defines the step
-        between the beginning of each new frame.
-
-        Centering is done by reflecting the waveform which is first centered around
-        `frame_idx * hop_length`.
+        Compute the log-Mel spectrogram of the provided audio.
         """
-        frames = []
-        for i in range(0, waveform.shape[0] + 1, self.hop_length):
-            half_window = (self.n_fft - 1) // 2 + 1
-            if center:
-                start = i - half_window if i > half_window else 0
-                end = (
-                    i + half_window
-                    if i < waveform.shape[0] - half_window
-                    else waveform.shape[0]
-                )
-
-                frame = waveform[start:end]
-
-                if start == 0:
-                    padd_width = (-i + half_window, 0)
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-                elif end == waveform.shape[0]:
-                    padd_width = (0, (i - waveform.shape[0] + half_window))
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-            else:
-                frame = waveform[i : i + self.n_fft]
-                frame_width = frame.shape[0]
-                if frame_width < waveform.shape[0]:
-                    frame = np.lib.pad(
-                        frame,
-                        pad_width=(0, self.n_fft - frame_width),
-                        mode="constant",
-                        constant_values=0,
-                    )
-
-            frames.append(frame)
-        return np.stack(frames, 0)
-
-    def stft(self, frames, window):
-        """
-        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
-        Should give the same results as `torch.stft`.
-        """
-        frame_size = frames.shape[1]
-        fft_size = self.n_fft
-
-        if fft_size is None:
-            fft_size = frame_size
-
-        if fft_size < frame_size:
-            raise ValueError("FFT size must greater or equal the frame size")
-        # number of FFT bins to store
-        num_fft_bins = (fft_size >> 1) + 1
-
-        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
-        fft_signal = np.zeros(fft_size)
-
-        for f, frame in enumerate(frames):
-            if window is not None:
-                np.multiply(frame, window, out=fft_signal[:frame_size])
-            else:
-                fft_signal[:frame_size] = frame
-            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
-        return data.T
-
-    def __call__(self, waveform, padding=True, chunk_length=None):
-        """
-        Compute the log-Mel spectrogram of the provided audio, gives similar results
-        whisper's original torch implementation with 1e-5 tolerance.
-        """
         if chunk_length is not None:
             self.n_samples = chunk_length * self.sampling_rate
             self.nb_max_frames = self.n_samples // self.hop_length
 
+        if waveform.dtype is not torch.float32:
+            waveform = waveform.to(torch.float32)
+
+        waveform = (
+            waveform.to(self.device)
+            if self.device == "cuda" and not waveform.is_cuda
+            else waveform
+        )
+
         if padding:
-            waveform = np.pad(waveform, [(0, self.n_samples)])
+            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
 
-        window = np.hanning(self.n_fft + 1)[:-1]
+        window = torch.hann_window(self.n_fft).to(waveform.device)
 
-        frames = self.fram_wave(waveform)
-        stft = self.stft(frames, window=window)
-        magnitudes = np.abs(stft[:, :-1]) ** 2
+        stft = torch.stft(
+            waveform, self.n_fft, self.hop_length, window=window, return_complex=True
+        )
+        magnitudes = stft[..., :-1].abs() ** 2
 
-        filters = self.mel_filters
-        mel_spec = filters @ magnitudes
+        mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
 
-        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
-        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec
+        # When the model is running on multiple GPUs, the output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        return log_spec.cpu() if to_cpu else log_spec
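An end-to-end sketch (illustrative, not part of the diff) of the rewritten pipeline, assuming the module path faster_whisper/feature_extraction.py and a synthetic waveform; the new to_cpu flag implements the multi-GPU hand-off described in the closing comment:

    import torch

    from faster_whisper.feature_extraction import FeatureExtractor

    extractor = FeatureExtractor(device="auto", feature_size=80)

    # One second of synthetic 16 kHz mono audio; __call__ expects a torch tensor.
    waveform = torch.randn(16000)

    # padding=True right-pads with 30 s of silence (self.n_samples samples),
    # matching the old numpy behaviour; to_cpu=True moves the result off the
    # GPU so the next job is free to run on a different device.
    features = extractor(waveform, padding=True, to_cpu=True)

    print(extractor.mel_filters.shape)  # (80, 201): n_mels x (1 + n_fft // 2)
    print(features.shape)               # (80, n_frames), log-Mel on CPU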
||||