From 5216d52d945b9182cee7ca30062d770198b03eca Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Sat, 11 Feb 2023 10:21:19 +0100
Subject: [PATCH] Initial commit

---
 LICENSE                             |  21 ++
 README.md                           |  63 +++++
 faster_whisper/__init__.py          |   1 +
 faster_whisper/audio.py             |  36 +++
 faster_whisper/feature_extractor.py | 163 +++++++++++++
 faster_whisper/transcribe.py        | 342 ++++++++++++++++++++++++++++
 requirements.conversion.txt         |   1 +
 requirements.txt                    |   3 +
 setup.py                            |  28 +++
 9 files changed, 658 insertions(+)
 create mode 100644 LICENSE
 create mode 100644 README.md
 create mode 100644 faster_whisper/__init__.py
 create mode 100644 faster_whisper/audio.py
 create mode 100644 faster_whisper/feature_extractor.py
 create mode 100644 faster_whisper/transcribe.py
 create mode 100644 requirements.conversion.txt
 create mode 100644 requirements.txt
 create mode 100644 setup.py

diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..62f34be
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Guillaume Klein
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..dab9637
--- /dev/null
+++ b/README.md
@@ -0,0 +1,63 @@
+# Faster Whisper transcription with CTranslate2
+
+This repository demonstrates how to implement Whisper transcription using [CTranslate2](https://github.com/OpenNMT/CTranslate2/), a fast inference engine for Transformer models.
+
+This implementation is about 4 times faster than [openai/whisper](https://github.com/openai/whisper) for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU.
+
+## Installation
+
+```bash
+pip install -e .[conversion]
+```
+
+The model conversion requires the modules `transformers` and `torch`, which are installed by the `[conversion]` extra. Once a model is converted, these modules are no longer needed and the installation can be simplified to:
+
+```bash
+pip install -e .
+```
+
+## Usage
+
+### Model conversion
+
+A Whisper model must first be converted into the CTranslate2 format. For example, the command below converts the "medium" Whisper model and saves the weights in FP16:
+
+```bash
+ct2-transformers-converter --model openai/whisper-medium --output_dir whisper-medium-ct2 --quantization float16
+```
+
+If needed, models can also be converted from code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
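+
+When converting from Python, a minimal sketch could look like the following (this assumes the `TransformersConverter` API linked above; the exact arguments may differ between CTranslate2 versions):
+
+```python
+import ctranslate2
+
+# Load the Hugging Face checkpoint and save a CTranslate2 model with FP16 weights.
+converter = ctranslate2.converters.TransformersConverter("openai/whisper-medium")
+converter.convert("whisper-medium-ct2", quantization="float16")
+```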
+
+### Transcription
+
+```python
+from faster_whisper import WhisperModel
+
+model_path = "whisper-medium-ct2/"
+
+# Run on GPU with FP16
+model = WhisperModel(model_path, device="cuda", compute_type="float16")
+
+# or run on GPU with INT8
+# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
+# or run on CPU with INT8
+# model = WhisperModel(model_path, device="cpu", compute_type="int8")
+
+segments, info = model.transcribe("audio.mp3", beam_size=5)
+
+print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
+
+for segment in segments:
+    print("[%ds -> %ds] %s" % (segment.start, segment.end, segment.text))
+```
+
+## Comparing performance against openai/whisper
+
+If you are comparing the performance against [openai/whisper](https://github.com/openai/whisper), you should make sure to use the same settings in both frameworks. In particular:
+
+* In openai/whisper, `model.transcribe` uses a beam size of 1 by default. A different beam size can have a significant impact on performance, so make sure to use the same beam size in both frameworks.
+* When running on CPU, make sure to set the same number of threads. Both frameworks read the environment variable `OMP_NUM_THREADS`, which can be set when running your script:
+
+```bash
+OMP_NUM_THREADS=4 python3 my_script.py
+```
diff --git a/faster_whisper/__init__.py b/faster_whisper/__init__.py
new file mode 100644
index 0000000..08b3009
--- /dev/null
+++ b/faster_whisper/__init__.py
@@ -0,0 +1 @@
+from faster_whisper.transcribe import WhisperModel
diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
new file mode 100644
index 0000000..eeb9f7f
--- /dev/null
+++ b/faster_whisper/audio.py
@@ -0,0 +1,36 @@
+import av
+import numpy as np
+
+
+def decode_audio(input_file, sampling_rate=16000):
+    """Decodes the audio.
+
+    Args:
+      input_file: Path to the input file or a file-like object.
+      sampling_rate: Resample the audio to this sample rate.
+
+    Returns:
+      A float32 Numpy array.
+    """
+    fifo = av.audio.fifo.AudioFifo()
+    resampler = av.audio.resampler.AudioResampler(
+        format="s16",
+        layout="mono",
+        rate=sampling_rate,
+    )
+
+    with av.open(input_file) as container:
+        # Decode and resample each audio frame.
+        for frame in container.decode(audio=0):
+            frame.pts = None
+            for new_frame in resampler.resample(frame):
+                fifo.write(new_frame)
+
+        # Flush the resampler.
+        for new_frame in resampler.resample(None):
+            fifo.write(new_frame)
+
+    frame = fifo.read()
+
+    # Convert s16 back to f32.
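+    # s16 samples span [-32768, 32767], so dividing by 32768.0 maps them into [-1.0, 1.0).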
+    return frame.to_ndarray().flatten().astype(np.float32) / 32768.0
diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py
new file mode 100644
index 0000000..6260009
--- /dev/null
+++ b/faster_whisper/feature_extractor.py
@@ -0,0 +1,163 @@
+import numpy as np
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
+class FeatureExtractor:
+    def __init__(
+        self,
+        feature_size=80,
+        sampling_rate=16000,
+        hop_length=160,
+        chunk_length=30,
+        n_fft=400,
+    ):
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.chunk_length = chunk_length
+        self.n_samples = chunk_length * sampling_rate
+        self.nb_max_frames = self.n_samples // hop_length
+        self.time_per_frame = hop_length / sampling_rate
+        self.sampling_rate = sampling_rate
+        self.mel_filters = self.get_mel_filters(
+            sampling_rate, n_fft, n_mels=feature_size
+        )
+
+    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
+        # Initialize the weights
+        n_mels = int(n_mels)
+        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+        # Center freqs of each FFT bin
+        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
+
+        # 'Center freqs' of mel bands - uniformly spaced between limits
+        min_mel = 0.0
+        max_mel = 45.245640471924965
+
+        mels = np.linspace(min_mel, max_mel, n_mels + 2)
+
+        mels = np.asanyarray(mels)
+
+        # Fill in the linear scale
+        f_min = 0.0
+        f_sp = 200.0 / 3
+        freqs = f_min + f_sp * mels
+
+        # And now the nonlinear scale
+        min_log_hz = 1000.0  # beginning of log region (Hz)
+        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
+        logstep = np.log(6.4) / 27.0  # step size for log region
+
+        # If we have vector data, vectorize
+        log_t = mels >= min_log_mel
+        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
+
+        mel_f = freqs
+
+        fdiff = np.diff(mel_f)
+        ramps = np.subtract.outer(mel_f, fftfreqs)
+
+        for i in range(n_mels):
+            # lower and upper slopes for all bins
+            lower = -ramps[i] / fdiff[i]
+            upper = ramps[i + 2] / fdiff[i + 1]
+
+            # .. then intersect them with each other and zero
+            weights[i] = np.maximum(0, np.minimum(lower, upper))
+
+        # Slaney-style mel is scaled to be approx constant energy per channel
+        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
+        weights *= enorm[:, np.newaxis]
+
+        return weights
+
+    def fram_wave(self, waveform, center=True):
+        """
+        Transform a raw waveform into a list of smaller waveforms.
+        The window length defines how much of the signal is
+        contained in each frame (smaller waveform), while the hop length defines the step
+        between the beginning of each new frame.
+        Centering is done by reflecting the waveform, which is first centered around
+        `frame_idx * hop_length`.
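+        For a waveform with N samples, this yields N // hop_length + 1 frames,
+        each of length n_fft.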
+        """
+        frames = []
+        for i in range(0, waveform.shape[0] + 1, self.hop_length):
+            half_window = (self.n_fft - 1) // 2 + 1
+            if center:
+                start = i - half_window if i > half_window else 0
+                end = (
+                    i + half_window
+                    if i < waveform.shape[0] - half_window
+                    else waveform.shape[0]
+                )
+
+                frame = waveform[start:end]
+
+                if start == 0:
+                    padd_width = (-i + half_window, 0)
+                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
+
+                elif end == waveform.shape[0]:
+                    padd_width = (0, (i - waveform.shape[0] + half_window))
+                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
+
+            else:
+                frame = waveform[i : i + self.n_fft]
+                frame_width = frame.shape[0]
+                if frame_width < waveform.shape[0]:
+                    frame = np.lib.pad(
+                        frame,
+                        pad_width=(0, self.n_fft - frame_width),
+                        mode="constant",
+                        constant_values=0,
+                    )
+
+            frames.append(frame)
+        return np.stack(frames, 0)
+
+    def stft(self, frames, window):
+        """
+        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
+        Should give the same results as `torch.stft`.
+        """
+        frame_size = frames.shape[1]
+        fft_size = self.n_fft
+
+        if fft_size is None:
+            fft_size = frame_size
+
+        if fft_size < frame_size:
+            raise ValueError("FFT size must be greater than or equal to the frame size")
+        # number of FFT bins to store
+        num_fft_bins = (fft_size >> 1) + 1
+
+        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
+        fft_signal = np.zeros(fft_size)
+
+        for f, frame in enumerate(frames):
+            if window is not None:
+                np.multiply(frame, window, out=fft_signal[:frame_size])
+            else:
+                fft_signal[:frame_size] = frame
+            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
+        return data.T
+
+    def __call__(self, waveform):
+        """
+        Compute the log-Mel spectrogram of the provided audio. This gives results similar
+        to Whisper's original torch implementation, within a 1e-5 tolerance.
+        """
+        window = np.hanning(self.n_fft + 1)[:-1]
+
+        frames = self.fram_wave(waveform)
+        stft = self.stft(frames, window=window)
+        magnitudes = np.abs(stft[:, :-1]) ** 2
+
+        filters = self.mel_filters
+        mel_spec = filters @ magnitudes
+
+        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
+        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+
+        return log_spec
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
new file mode 100644
index 0000000..a840253
--- /dev/null
+++ b/faster_whisper/transcribe.py
@@ -0,0 +1,342 @@
+import collections
+import os
+import zlib
+
+import ctranslate2
+import numpy as np
+import tokenizers
+
+from faster_whisper.audio import decode_audio
+from faster_whisper.feature_extractor import FeatureExtractor
+
+
+class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
+    pass
+
+
+class AudioInfo(
+    collections.namedtuple("AudioInfo", ("language", "language_probability"))
+):
+    pass
+
+
+class TranscriptionOptions(
+    collections.namedtuple(
+        "TranscriptionOptions",
+        (
+            "beam_size",
+            "best_of",
+            "patience",
+            "log_prob_threshold",
+            "no_speech_threshold",
+            "compression_ratio_threshold",
+            "condition_on_previous_text",
+            "temperatures",
+        ),
+    )
+):
+    pass
+
+
+class WhisperModel:
+    def __init__(
+        self,
+        model_path,
+        device="auto",
+        compute_type="default",
+        cpu_threads=0,
+    ):
+        """Initializes the Whisper model.
+
+        Args:
+          model_path: Path to the converted model.
+          device: Device to use for computation ("cpu", "cuda", "auto").
+          compute_type: Type to use for computation.
+            See https://opennmt.net/CTranslate2/quantization.html.
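+            Common values include "default", "float16", "int8_float16", and "int8".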
+          cpu_threads: Number of threads to use when running on CPU (4 by default).
+            A non-zero value overrides the OMP_NUM_THREADS environment variable.
+        """
+        self.model = ctranslate2.models.Whisper(
+            model_path,
+            device=device,
+            compute_type=compute_type,
+            intra_threads=cpu_threads,
+        )
+
+        self.feature_extractor = FeatureExtractor()
+        self.decoder = tokenizers.decoders.ByteLevel()
+
+        with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
+            self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]
+            self.tokens_to_ids = {
+                token: i for i, token in enumerate(self.ids_to_tokens)
+            }
+
+        self.eot_id = self.tokens_to_ids["<|endoftext|>"]
+        self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
+        self.input_stride = 2
+        self.time_precision = 0.02
+        self.max_length = 448
+
+    def transcribe(
+        self,
+        input_file,
+        language=None,
+        beam_size=5,
+        best_of=5,
+        patience=1,
+        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+        compression_ratio_threshold=2.4,
+        log_prob_threshold=-1.0,
+        no_speech_threshold=0.6,
+        condition_on_previous_text=True,
+    ):
+        """Transcribes an input file.
+
+        Arguments:
+          input_file: Path to the input file or a file-like object.
+          language: The language spoken in the audio. If not set, the language will be
+            detected in the first 30 seconds of audio.
+          beam_size: Beam size to use for decoding.
+          best_of: Number of candidates when sampling with non-zero temperature.
+          patience: Beam search patience factor.
+          temperature: Temperature for sampling. It can be a tuple of temperatures,
+            which will be successively used upon failures according to either
+            `compression_ratio_threshold` or `log_prob_threshold`.
+          compression_ratio_threshold: If the gzip compression ratio is above this value,
+            treat as failed.
+          log_prob_threshold: If the average log probability over sampled tokens is
+            below this value, treat as failed.
+          no_speech_threshold: If the no_speech probability is higher than this value AND
+            the average log probability over sampled tokens is below `log_prob_threshold`,
+            consider the segment as silent.
+          condition_on_previous_text: If True, the previous output of the model is provided
+            as a prompt for the next window; disabling may make the text inconsistent across
+            windows, but the model becomes less prone to getting stuck in a failure loop,
+            such as repetition looping or timestamps going out of sync.
+
+        Returns:
+          A tuple with:
+
+            - a generator over transcribed segments
+            - an instance of AudioInfo
+        """
+        audio = decode_audio(
+            input_file, sampling_rate=self.feature_extractor.sampling_rate
+        )
+        features = self.feature_extractor(audio)
+
+        if language is None:
+            segment = self.get_segment(features)
+            input = self.get_input(segment)
+            results = self.model.detect_language(input)
+            language_token, language_probability = results[0][0]
+            language = language_token[2:-2]
+        else:
+            language_probability = 1
+
+        options = TranscriptionOptions(
+            beam_size=beam_size,
+            best_of=best_of,
+            patience=patience,
+            log_prob_threshold=log_prob_threshold,
+            no_speech_threshold=no_speech_threshold,
+            compression_ratio_threshold=compression_ratio_threshold,
+            condition_on_previous_text=condition_on_previous_text,
+            temperatures=(
+                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+            ),
+        )
+
+        segments = self.generate_segments(features, language, options)
+
+        audio_info = AudioInfo(
+            language=language,
+            language_probability=language_probability,
+        )
+
+        return segments, audio_info
+
+    def generate_segments(self, features, language, options):
+        tokenized_segments = self.generate_tokenized_segments(
+            features, language, options
+        )
+
+        for start, end, tokens in tokenized_segments:
+            text = self.decode_text_tokens(tokens)
+            if not text.strip():
+                continue
+
+            yield Segment(
+                start=start,
+                end=end,
+                text=text,
+            )
+
+    def generate_tokenized_segments(self, features, language, options):
+        num_frames = features.shape[-1]
+        offset = 0
+        all_tokens = []
+        prompt_reset_since = 0
+
+        while offset < num_frames:
+            time_offset = offset * self.feature_extractor.time_per_frame
+            segment = self.get_segment(features, offset)
+            segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
+
+            previous_tokens = all_tokens[prompt_reset_since:]
+            prompt = self.get_prompt(language, previous_tokens)
+            result, temperature = self.generate_with_fallback(segment, prompt, options)
+
+            if (
+                result.no_speech_prob > options.no_speech_threshold
+                and result.scores[0] < options.log_prob_threshold
+            ):
+                offset += segment.shape[-1]
+                continue
+
+            tokens = result.sequences_ids[0]
+
+            consecutive_timestamps = [
+                i
+                for i in range(len(tokens))
+                if i > 0
+                and tokens[i] >= self.timestamp_begin_id
+                and tokens[i - 1] >= self.timestamp_begin_id
+            ]
+
+            if len(consecutive_timestamps) > 0:
+                last_slice = 0
+                for i, current_slice in enumerate(consecutive_timestamps):
+                    sliced_tokens = tokens[last_slice:current_slice]
+                    start_timestamp_position = (
+                        sliced_tokens[0] - self.timestamp_begin_id
+                    )
+                    end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
+                    start_time = (
+                        time_offset + start_timestamp_position * self.time_precision
+                    )
+                    end_time = (
+                        time_offset + end_timestamp_position * self.time_precision
+                    )
+
+                    last_in_window = i + 1 == len(consecutive_timestamps)
+
+                    # Include the last timestamp so that all tokens are included in a segment.
+                    if last_in_window:
+                        sliced_tokens.append(tokens[current_slice])
+
+                    yield start_time, end_time, sliced_tokens
+                    last_slice = current_slice
+
+                last_timestamp_position = (
+                    tokens[last_slice - 1] - self.timestamp_begin_id
+                )
+                offset += last_timestamp_position * self.input_stride
+                all_tokens.extend(tokens[: last_slice + 1])
+
+            else:
+                duration = segment_duration
+                timestamps = [
+                    token for token in tokens if token >= self.timestamp_begin_id
+                ]
+                if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
+                    last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
+                    duration = last_timestamp_position * self.time_precision
+
+                yield time_offset, time_offset + duration, tokens
+
+                offset += segment.shape[-1]
+                all_tokens.extend(tokens)
+
+            if not options.condition_on_previous_text or temperature > 0.5:
+                prompt_reset_since = len(all_tokens)
+
+    def decode_text_tokens(self, tokens):
+        text_tokens = [
+            self.ids_to_tokens[token] for token in tokens if token < self.eot_id
+        ]
+
+        return self.decoder.decode(text_tokens)
+
+    def generate_with_fallback(self, segment, prompt, options):
+        features = self.get_input(segment)
+        result = None
+        final_temperature = None
+
+        for temperature in options.temperatures:
+            if temperature > 0:
+                kwargs = {
+                    "beam_size": 1,
+                    "num_hypotheses": options.best_of,
+                    "sampling_topk": 0,
+                    "sampling_temperature": temperature,
+                }
+            else:
+                kwargs = {
+                    "beam_size": options.beam_size,
+                    "patience": options.patience,
+                }
+
+            final_temperature = temperature
+            result = self.model.generate(
+                features,
+                [prompt],
+                max_length=self.max_length,
+                return_scores=True,
+                return_no_speech_prob=True,
+                **kwargs,
+            )[0]
+
+            tokens = result.sequences_ids[0]
+            text = self.decode_text_tokens(tokens)
+            compression_ratio = get_compression_ratio(text)
+
+            if (
+                compression_ratio <= options.compression_ratio_threshold
+                and result.scores[0] >= options.log_prob_threshold
+            ):
+                break
+
+        return result, final_temperature
+
+    def get_prompt(self, language, previous_tokens):
+        prompt = []
+
+        if previous_tokens:
+            prompt.append(self.tokens_to_ids["<|startofprev|>"])
+            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
+
+        prompt += [
+            self.tokens_to_ids["<|startoftranscript|>"],
+            self.tokens_to_ids["<|%s|>" % language],
+            self.tokens_to_ids["<|transcribe|>"],
+        ]
+
+        return prompt
+
+    def get_segment(self, features, offset=0):
+        if offset > 0:
+            features = features[:, offset:]
+
+        num_frames = features.shape[-1]
+        required_num_frames = self.feature_extractor.nb_max_frames
+
+        if num_frames > required_num_frames:
+            features = features[:, :required_num_frames]
+        elif num_frames < required_num_frames:
+            pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
+            features = np.pad(features, pad_widths)
+
+        features = np.ascontiguousarray(features)
+        return features
+
+    def get_input(self, segment):
+        segment = np.expand_dims(segment, 0)
+        segment = ctranslate2.StorageView.from_array(segment)
+        return segment
+
+
+def get_compression_ratio(text):
+    text_bytes = text.encode("utf-8")
+    return len(text_bytes) / len(zlib.compress(text_bytes))
diff --git a/requirements.conversion.txt b/requirements.conversion.txt
new file mode 100644
index 0000000..56fdf5f
--- /dev/null
+++ b/requirements.conversion.txt
@@ -0,0 +1 @@
+transformers[torch]>=4.23
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9c39ab1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+av==10.*
+ctranslate2>=3.5,<4
+tokenizers==0.13.*
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..b33616d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,28 @@
+import os
+
+from setuptools import find_packages, setup
+
+
+def get_requirements(path):
+    with open(path, encoding="utf-8") as requirements:
+        return [requirement.strip() for requirement in requirements]
+
+
+base_dir = os.path.dirname(os.path.abspath(__file__))
+install_requires = get_requirements(os.path.join(base_dir, "requirements.txt"))
+conversion_requires = get_requirements(
+    os.path.join(base_dir, "requirements.conversion.txt")
+)
+
+setup(
+    name="faster-whisper",
+    version="0.1.0",
+    description="Faster Whisper transcription with CTranslate2",
+    author="Guillaume Klein",
+    python_requires=">=3.7",
+    install_requires=install_requires,
+    extras_require={
+        "conversion": conversion_requires,
+    },
+    packages=find_packages(),
+)