Initial commit

This commit is contained in:
Guillaume Klein
2023-02-11 10:21:19 +01:00
commit 5216d52d94
9 changed files with 658 additions and 0 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Guillaume Klein
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

63
README.md Normal file
View File

@@ -0,0 +1,63 @@
# Faster Whisper transcription with CTranslate2
This repository demonstrates how to implement the Whisper transcription using [CTranslate2](https://github.com/OpenNMT/CTranslate2/), which is a fast inference engine for Transformer models.
This implementation is about 4 times faster than [openai/whisper](https://github.com/openai/whisper) for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU.
## Installation
```bash
pip install -e .[conversion]
```
The model conversion requires the modules `transformers` and `torch` which are installed by the `[conversion]` requirement. Once a model is converted, these modules are no longer needed and the installation could be simplified to:
```bash
pip install -e .
```
## Usage
### Model conversion
A Whisper model should be first converted into the CTranslate2 format. For example the command below converts the "medium" Whisper model and saves the weights in FP16:
```bash
ct2-transformers-converter --model openai/whisper-medium --output_dir whisper-medium-ct2 --quantization float16
```
If needed, models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
### Transcription
```python
from faster_whisper import WhisperModel
model_path = "whisper-medium-ct2/"
# Run on GPU with FP16
model = WhisperModel(model_path, device="cuda", compute_type="float16")
# or run on GPU with INT8
# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_path, device="cpu", compute_type="int8")
segments, info = model.transcribe("audio.mp3", beam_size=5)
print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
for segment in segments:
print("[%ds -> %ds] %s" % (segment.start, segment.end, segment.text))
```
## Comparing performance against openai/whisper
If you are comparing the performance against [openai/whisper](https://github.com/openai/whisper), you should make sure to use the same settings in both frameworks. In particular:
* In openai/whisper, `model.transcribe` uses a beam size of 1 by default. A different beam size will have an important impact on performance so make to use the same.
* When running on CPU, make sure to set the same number of threads. Both frameworks will read the environment variable `OMP_NUM_THREADS`, which can be set when running your script:
```bash
OMP_NUM_THREADS=4 python3 my_script.py
```

View File

@@ -0,0 +1 @@
from faster_whisper.transcribe import WhisperModel

36
faster_whisper/audio.py Normal file
View File

@@ -0,0 +1,36 @@
import av
import numpy as np
def decode_audio(input_file, sampling_rate=16000):
"""Decodes the audio.
Args:
input_file: Path to the input file or a file-like object.
sampling_rate: Resample the audio to this sample rate.
Returns:
A float32 Numpy array.
"""
fifo = av.audio.fifo.AudioFifo()
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono",
rate=sampling_rate,
)
with av.open(input_file) as container:
# Decode and resample each audio frame.
for frame in container.decode(audio=0):
frame.pts = None
for new_frame in resampler.resample(frame):
fifo.write(new_frame)
# Flush the resampler.
for new_frame in resampler.resample(None):
fifo.write(new_frame)
frame = fifo.read()
# Convert s16 back to f32.
return frame.to_ndarray().flatten().astype(np.float32) / 32768.0

View File

@@ -0,0 +1,163 @@
import numpy as np
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
class FeatureExtractor:
def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.time_per_frame = hop_length / sampling_rate
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms.
The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step
between the beginning of each new frame.
Centering is done by reflecting the waveform which is first centered around
`frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = (
i + half_window
if i < waveform.shape[0] - half_window
else waveform.shape[0]
)
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame,
pad_width=(0, self.n_fft - frame_width),
mode="constant",
constant_values=0,
)
frames.append(frame)
return np.stack(frames, 0)
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
Should give the same results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
def __call__(self, waveform):
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance.
"""
window = np.hanning(self.n_fft + 1)[:-1]
frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
filters = self.mel_filters
mel_spec = filters @ magnitudes
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec

View File

@@ -0,0 +1,342 @@
import collections
import os
import zlib
import ctranslate2
import numpy as np
import tokenizers
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
pass
class AudioInfo(
collections.namedtuple("AudioInfo", ("language", "language_probability"))
):
pass
class TranscriptionOptions(
collections.namedtuple(
"TranscriptionOptions",
(
"beam_size",
"best_of",
"patience",
"log_prob_threshold",
"no_speech_threshold",
"compression_ratio_threshold",
"condition_on_previous_text",
"temperatures",
),
)
):
pass
class WhisperModel:
def __init__(
self,
model_path,
device="auto",
compute_type="default",
cpu_threads=0,
):
"""Initializes the Whisper model.
Args:
model_path: Path to the converted model.
device: Device to use for computation ("cpu", "cuda", "auto").
compute_type: Type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
cpu_threads: Number of threads to use when running on CPU (4 by default).
On non zero value overrides the OMP_NUM_THREADS environment variable.
"""
self.model = ctranslate2.models.Whisper(
model_path,
device=device,
compute_type=compute_type,
intra_threads=cpu_threads,
)
self.feature_extractor = FeatureExtractor()
self.decoder = tokenizers.decoders.ByteLevel()
with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]
self.tokens_to_ids = {
token: i for i, token in enumerate(self.ids_to_tokens)
}
self.eot_id = self.tokens_to_ids["<|endoftext|>"]
self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
self.input_stride = 2
self.time_precision = 0.02
self.max_length = 448
def transcribe(
self,
input_file,
language=None,
beam_size=5,
best_of=5,
patience=1,
temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
):
"""Transcribes an input file.
Arguments:
input_file: Path to the input file or a file-like object.
language: The language spoken in the audio. If not set, the language will be
detected in the first 30 seconds of audio.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: If the gzip compression ratio is above this value,
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `logprob_threshold`,
consider the segment as silent.
condition_on_previous_text: If True, the previous output of the model is provided
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync.
Returns:
A tuple with:
- a generator over transcribed segments
- an instance of AudioInfo
"""
audio = decode_audio(
input_file, sampling_rate=self.feature_extractor.sampling_rate
)
features = self.feature_extractor(audio)
if language is None:
segment = self.get_segment(features)
input = self.get_input(segment)
results = self.model.detect_language(input)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
else:
language_probability = 1
options = TranscriptionOptions(
beam_size=beam_size,
best_of=best_of,
patience=patience,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
condition_on_previous_text=condition_on_previous_text,
temperatures=(
temperature if isinstance(temperature, (list, tuple)) else [temperature]
),
)
segments = self.generate_segments(features, language, options)
audio_info = AudioInfo(
language=language,
language_probability=language_probability,
)
return segments, audio_info
def generate_segments(self, features, language, options):
tokenized_segments = self.generate_tokenized_segments(
features, language, options
)
for start, end, tokens in tokenized_segments:
text = self.decode_text_tokens(tokens)
if not text.strip():
continue
yield Segment(
start=start,
end=end,
text=text,
)
def generate_tokenized_segments(self, features, language, options):
num_frames = features.shape[-1]
offset = 0
all_tokens = []
prompt_reset_since = 0
while offset < num_frames:
time_offset = offset * self.feature_extractor.time_per_frame
segment = self.get_segment(features, offset)
segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(language, previous_tokens)
result, temperature = self.generate_with_fallback(segment, prompt, options)
if (
result.no_speech_prob > options.no_speech_threshold
and result.scores[0] < options.log_prob_threshold
):
offset += segment.shape[-1]
continue
tokens = result.sequences_ids[0]
consecutive_timestamps = [
i
for i in range(len(tokens))
if i > 0
and tokens[i] >= self.timestamp_begin_id
and tokens[i - 1] >= self.timestamp_begin_id
]
if len(consecutive_timestamps) > 0:
last_slice = 0
for i, current_slice in enumerate(consecutive_timestamps):
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0] - self.timestamp_begin_id
)
end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
start_time = (
time_offset + start_timestamp_position * self.time_precision
)
end_time = (
time_offset + end_timestamp_position * self.time_precision
)
last_in_window = i + 1 == len(consecutive_timestamps)
# Include the last timestamp so that all tokens are included in a segment.
if last_in_window:
sliced_tokens.append(tokens[current_slice])
yield start_time, end_time, sliced_tokens
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1] - self.timestamp_begin_id
)
offset += last_timestamp_position * self.input_stride
all_tokens.extend(tokens[: last_slice + 1])
else:
duration = segment_duration
timestamps = [
token for token in tokens if token >= self.timestamp_begin_id
]
if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
duration = last_timestamp_position * self.time_precision
yield time_offset, time_offset + duration, tokens
offset += segment.shape[-1]
all_tokens.extend(tokens)
if not options.condition_on_previous_text or temperature > 0.5:
prompt_reset_since = len(all_tokens)
def decode_text_tokens(self, tokens):
text_tokens = [
self.ids_to_tokens[token] for token in tokens if token < self.eot_id
]
return self.decoder.decode(text_tokens)
def generate_with_fallback(self, segment, prompt, options):
features = self.get_input(segment)
result = None
final_temperature = None
for temperature in options.temperatures:
if temperature > 0:
kwargs = {
"beam_size": 1,
"num_hypotheses": options.best_of,
"sampling_topk": 0,
"sampling_temperature": temperature,
}
else:
kwargs = {
"beam_size": options.beam_size,
"patience": options.patience,
}
final_temperature = temperature
result = self.model.generate(
features,
[prompt],
max_length=self.max_length,
return_scores=True,
return_no_speech_prob=True,
**kwargs,
)[0]
tokens = result.sequences_ids[0]
text = self.decode_text_tokens(tokens)
compression_ratio = get_compression_ratio(text)
if (
compression_ratio <= options.compression_ratio_threshold
and result.scores[0] >= options.log_prob_threshold
):
break
return result, final_temperature
def get_prompt(self, language, previous_tokens):
prompt = []
if previous_tokens:
prompt.append(self.tokens_to_ids["<|startofprev|>"])
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt += [
self.tokens_to_ids["<|startoftranscript|>"],
self.tokens_to_ids["<|%s|>" % language],
self.tokens_to_ids["<|transcribe|>"],
]
return prompt
def get_segment(self, features, offset=0):
if offset > 0:
features = features[:, offset:]
num_frames = features.shape[-1]
required_num_frames = self.feature_extractor.nb_max_frames
if num_frames > required_num_frames:
features = features[:, :required_num_frames]
elif num_frames < required_num_frames:
pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
features = np.pad(features, pad_widths)
features = np.ascontiguousarray(features)
return features
def get_input(self, segment):
segment = np.expand_dims(segment, 0)
segment = ctranslate2.StorageView.from_array(segment)
return segment
def get_compression_ratio(text):
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))

View File

@@ -0,0 +1 @@
transformers[torch]>=4.23

3
requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
av==10.*
ctranslate2>=3.5,<4
tokenizers==0.13.*

28
setup.py Normal file
View File

@@ -0,0 +1,28 @@
import os
from setuptools import find_packages, setup
def get_requirements(path):
with open(path, encoding="utf-8") as requirements:
return [requirement.strip() for requirement in requirements]
base_dir = os.path.dirname(os.path.abspath(__file__))
install_requires = get_requirements(os.path.join(base_dir, "requirements.txt"))
conversion_requires = get_requirements(
os.path.join(base_dir, "requirements.conversion.txt")
)
setup(
name="faster-whisper",
version="0.1.0",
description="Faster Whisper transcription with CTranslate2",
author="Guillaume Klein",
python_requires=">=3.7",
install_requires=install_requires,
extras_require={
"conversion": conversion_requires,
},
packages=find_packages(),
)