Initial commit
21  LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Guillaume Klein

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
63  README.md  Normal file
@@ -0,0 +1,63 @@
# Faster Whisper transcription with CTranslate2

This repository demonstrates how to implement Whisper transcription using [CTranslate2](https://github.com/OpenNMT/CTranslate2/), which is a fast inference engine for Transformer models.

This implementation is about 4 times faster than [openai/whisper](https://github.com/openai/whisper) for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU.

## Installation

```bash
pip install -e .[conversion]
```

The model conversion requires the modules `transformers` and `torch`, which are installed by the `[conversion]` requirement. Once a model is converted, these modules are no longer needed and the installation can be simplified to:

```bash
pip install -e .
```

## Usage

### Model conversion

A Whisper model should first be converted to the CTranslate2 format. For example, the command below converts the "medium" Whisper model and saves the weights in FP16:

```bash
ct2-transformers-converter --model openai/whisper-medium --output_dir whisper-medium-ct2 --quantization float16
```

If needed, models can also be converted from code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
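For reference, a minimal sketch of the same conversion done from Python with the converter class linked above (the output directory name is arbitrary, and the `[conversion]` dependencies are assumed to be installed):

```python
import ctranslate2

# Download "openai/whisper-medium" from the Hugging Face Hub and convert it to CTranslate2.
converter = ctranslate2.converters.TransformersConverter("openai/whisper-medium")
converter.convert("whisper-medium-ct2", quantization="float16")
```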
### Transcription

```python
from faster_whisper import WhisperModel

model_path = "whisper-medium-ct2/"

# Run on GPU with FP16
model = WhisperModel(model_path, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_path, device="cpu", compute_type="int8")

segments, info = model.transcribe("audio.mp3", beam_size=5)

print("Detected language '%s' with probability %f" % (info.language, info.language_probability))

for segment in segments:
    print("[%ds -> %ds] %s" % (segment.start, segment.end, segment.text))
```

## Comparing performance against openai/whisper

If you are comparing the performance against [openai/whisper](https://github.com/openai/whisper), you should make sure to use the same settings in both frameworks. In particular:

* In openai/whisper, `model.transcribe` uses a beam size of 1 by default. A different beam size has a significant impact on performance, so make sure to use the same value in both frameworks (see the sketch after this list).
* When running on CPU, make sure to set the same number of threads. Both frameworks read the environment variable `OMP_NUM_THREADS`, which can be set when running your script:

```bash
OMP_NUM_THREADS=4 python3 my_script.py
```
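For example, a comparison sketch that keeps the decoding settings aligned; the audio file and model paths are placeholders, and the openai/whisper side uses the standard `whisper` package API:

```python
# Run both frameworks with the same thread count, e.g.: OMP_NUM_THREADS=4 python3 compare.py
import whisper  # openai/whisper

from faster_whisper import WhisperModel

# openai/whisper with an explicit beam size (its default would be 1).
openai_model = whisper.load_model("medium", device="cpu")
openai_result = openai_model.transcribe("audio.mp3", beam_size=5)

# faster-whisper with the same beam size, also on CPU.
faster_model = WhisperModel("whisper-medium-ct2/", device="cpu", compute_type="int8")
segments, info = faster_model.transcribe("audio.mp3", beam_size=5)
faster_text = "".join(segment.text for segment in segments)

print(openai_result["text"])
print(faster_text)
```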
1  faster_whisper/__init__.py  Normal file
@@ -0,0 +1 @@
from faster_whisper.transcribe import WhisperModel
36  faster_whisper/audio.py  Normal file
@@ -0,0 +1,36 @@
import av
import numpy as np


def decode_audio(input_file, sampling_rate=16000):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.

    Returns:
      A float32 Numpy array.
    """
    fifo = av.audio.fifo.AudioFifo()
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono",
        rate=sampling_rate,
    )

    with av.open(input_file) as container:
        # Decode and resample each audio frame.
        for frame in container.decode(audio=0):
            frame.pts = None
            for new_frame in resampler.resample(frame):
                fifo.write(new_frame)

        # Flush the resampler.
        for new_frame in resampler.resample(None):
            fifo.write(new_frame)

    frame = fifo.read()

    # Convert s16 back to f32.
    return frame.to_ndarray().flatten().astype(np.float32) / 32768.0
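A minimal usage sketch for `decode_audio` (the audio file name is a placeholder): the result is mono float32 in [-1, 1] at the requested sample rate.

```python
from faster_whisper.audio import decode_audio

audio = decode_audio("audio.mp3", sampling_rate=16000)
print(audio.dtype, audio.shape)  # float32, (num_samples,) with 16000 samples per second
```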
163  faster_whisper/feature_extractor.py  Normal file
@@ -0,0 +1,163 @@
import numpy as np


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
class FeatureExtractor:
    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
    ):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.time_per_frame = hop_length / sampling_rate
        self.sampling_rate = sampling_rate
        self.mel_filters = self.get_mel_filters(
            sampling_rate, n_fft, n_mels=feature_size
        )

    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
        # Initialize the weights
        n_mels = int(n_mels)
        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)

        # Center freqs of each FFT bin
        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965

        mels = np.linspace(min_mel, max_mel, n_mels + 2)

        mels = np.asanyarray(mels)

        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels

        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region

        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))

        mel_f = freqs

        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)

        for i in range(n_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]

            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
        weights *= enorm[:, np.newaxis]

        return weights

    def fram_wave(self, waveform, center=True):
        """
        Transform a raw waveform into a list of smaller waveforms.
        The window length defines how much of the signal is
        contained in each frame (smaller waveform), while the hop length defines the step
        between the beginning of each new frame.
        Centering is done by reflecting the waveform, which is first centered around
        `frame_idx * hop_length`.
        """
        frames = []
        for i in range(0, waveform.shape[0] + 1, self.hop_length):
            half_window = (self.n_fft - 1) // 2 + 1
            if center:
                start = i - half_window if i > half_window else 0
                end = (
                    i + half_window
                    if i < waveform.shape[0] - half_window
                    else waveform.shape[0]
                )

                frame = waveform[start:end]

                if start == 0:
                    padd_width = (-i + half_window, 0)
                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")

                elif end == waveform.shape[0]:
                    padd_width = (0, (i - waveform.shape[0] + half_window))
                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")

            else:
                frame = waveform[i : i + self.n_fft]
                frame_width = frame.shape[0]
                if frame_width < waveform.shape[0]:
                    frame = np.lib.pad(
                        frame,
                        pad_width=(0, self.n_fft - frame_width),
                        mode="constant",
                        constant_values=0,
                    )

            frames.append(frame)
        return np.stack(frames, 0)

    def stft(self, frames, window):
        """
        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
        Should give the same results as `torch.stft`.
        """
        frame_size = frames.shape[1]
        fft_size = self.n_fft

        if fft_size is None:
            fft_size = frame_size

        if fft_size < frame_size:
            raise ValueError("FFT size must be greater than or equal to the frame size")

        # number of FFT bins to store
        num_fft_bins = (fft_size >> 1) + 1

        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
        fft_signal = np.zeros(fft_size)

        for f, frame in enumerate(frames):
            if window is not None:
                np.multiply(frame, window, out=fft_signal[:frame_size])
            else:
                fft_signal[:frame_size] = frame
            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
        return data.T

    def __call__(self, waveform):
        """
        Computes the log-Mel spectrogram of the provided audio, giving similar results
        to Whisper's original torch implementation within a 1e-5 tolerance.
        """
        window = np.hanning(self.n_fft + 1)[:-1]

        frames = self.fram_wave(waveform)
        stft = self.stft(frames, window=window)
        magnitudes = np.abs(stft[:, :-1]) ** 2

        filters = self.mel_filters
        mel_spec = filters @ magnitudes

        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        return log_spec
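A usage sketch combining the two modules above (the file name is a placeholder): with the default settings, the extractor returns an 80-bin log-Mel spectrogram with one frame every 10 ms.

```python
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor

feature_extractor = FeatureExtractor()  # 80 mel bins, 16 kHz, 160-sample hop
audio = decode_audio("audio.mp3", sampling_rate=feature_extractor.sampling_rate)
features = feature_extractor(audio)
print(features.shape)  # (80, num_frames), one frame per 10 ms of audio
```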
342  faster_whisper/transcribe.py  Normal file
@@ -0,0 +1,342 @@
import collections
import os
import zlib

import ctranslate2
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor


class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
    pass


class AudioInfo(
    collections.namedtuple("AudioInfo", ("language", "language_probability"))
):
    pass


class TranscriptionOptions(
    collections.namedtuple(
        "TranscriptionOptions",
        (
            "beam_size",
            "best_of",
            "patience",
            "log_prob_threshold",
            "no_speech_threshold",
            "compression_ratio_threshold",
            "condition_on_previous_text",
            "temperatures",
        ),
    )
):
    pass


class WhisperModel:
    def __init__(
        self,
        model_path,
        device="auto",
        compute_type="default",
        cpu_threads=0,
    ):
        """Initializes the Whisper model.

        Args:
          model_path: Path to the converted model.
          device: Device to use for computation ("cpu", "cuda", "auto").
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non-zero value overrides the OMP_NUM_THREADS environment variable.
        """
        self.model = ctranslate2.models.Whisper(
            model_path,
            device=device,
            compute_type=compute_type,
            intra_threads=cpu_threads,
        )

        self.feature_extractor = FeatureExtractor()
        self.decoder = tokenizers.decoders.ByteLevel()

        with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
            self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]
            self.tokens_to_ids = {
                token: i for i, token in enumerate(self.ids_to_tokens)
            }

        self.eot_id = self.tokens_to_ids["<|endoftext|>"]
        self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
        self.input_stride = 2
        self.time_precision = 0.02
        self.max_length = 448

    def transcribe(
        self,
        input_file,
        language=None,
        beam_size=5,
        best_of=5,
        patience=1,
        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        compression_ratio_threshold=2.4,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        condition_on_previous_text=True,
    ):
        """Transcribes an input file.

        Arguments:
          input_file: Path to the input file or a file-like object.
          language: The language spoken in the audio. If not set, the language will be
            detected in the first 30 seconds of audio.
          beam_size: Beam size to use for decoding.
          best_of: Number of candidates when sampling with non-zero temperature.
          patience: Beam search patience factor.
          temperature: Temperature for sampling. It can be a tuple of temperatures,
            which will be successively used upon failures according to either
            `compression_ratio_threshold` or `log_prob_threshold`.
          compression_ratio_threshold: If the gzip compression ratio is above this value,
            treat as failed.
          log_prob_threshold: If the average log probability over sampled tokens is
            below this value, treat as failed.
          no_speech_threshold: If the no_speech probability is higher than this value AND
            the average log probability over sampled tokens is below `log_prob_threshold`,
            consider the segment as silent.
          condition_on_previous_text: If True, the previous output of the model is provided
            as a prompt for the next window; disabling may make the text inconsistent across
            windows, but the model becomes less prone to getting stuck in a failure loop,
            such as repetition looping or timestamps going out of sync.

        Returns:
          A tuple with:

            - a generator over transcribed segments
            - an instance of AudioInfo
        """
        audio = decode_audio(
            input_file, sampling_rate=self.feature_extractor.sampling_rate
        )
        features = self.feature_extractor(audio)

        if language is None:
            segment = self.get_segment(features)
            input = self.get_input(segment)
            results = self.model.detect_language(input)
            language_token, language_probability = results[0][0]
            language = language_token[2:-2]
        else:
            language_probability = 1

        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            condition_on_previous_text=condition_on_previous_text,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
        )

        segments = self.generate_segments(features, language, options)

        audio_info = AudioInfo(
            language=language,
            language_probability=language_probability,
        )

        return segments, audio_info

    def generate_segments(self, features, language, options):
        tokenized_segments = self.generate_tokenized_segments(
            features, language, options
        )

        for start, end, tokens in tokenized_segments:
            text = self.decode_text_tokens(tokens)
            if not text.strip():
                continue

            yield Segment(
                start=start,
                end=end,
                text=text,
            )

    def generate_tokenized_segments(self, features, language, options):
        num_frames = features.shape[-1]
        offset = 0
        all_tokens = []
        prompt_reset_since = 0

        while offset < num_frames:
            time_offset = offset * self.feature_extractor.time_per_frame
            segment = self.get_segment(features, offset)
            segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame

            previous_tokens = all_tokens[prompt_reset_since:]
            prompt = self.get_prompt(language, previous_tokens)
            result, temperature = self.generate_with_fallback(segment, prompt, options)

            if (
                result.no_speech_prob > options.no_speech_threshold
                and result.scores[0] < options.log_prob_threshold
            ):
                offset += segment.shape[-1]
                continue

            tokens = result.sequences_ids[0]

            consecutive_timestamps = [
                i
                for i in range(len(tokens))
                if i > 0
                and tokens[i] >= self.timestamp_begin_id
                and tokens[i - 1] >= self.timestamp_begin_id
            ]

            if len(consecutive_timestamps) > 0:
                last_slice = 0
                for i, current_slice in enumerate(consecutive_timestamps):
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0] - self.timestamp_begin_id
                    )
                    end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
                    start_time = (
                        time_offset + start_timestamp_position * self.time_precision
                    )
                    end_time = (
                        time_offset + end_timestamp_position * self.time_precision
                    )

                    last_in_window = i + 1 == len(consecutive_timestamps)

                    # Include the last timestamp so that all tokens are included in a segment.
                    if last_in_window:
                        sliced_tokens.append(tokens[current_slice])

                    yield start_time, end_time, sliced_tokens
                    last_slice = current_slice

                last_timestamp_position = (
                    tokens[last_slice - 1] - self.timestamp_begin_id
                )
                offset += last_timestamp_position * self.input_stride
                all_tokens.extend(tokens[: last_slice + 1])

            else:
                duration = segment_duration
                timestamps = [
                    token for token in tokens if token >= self.timestamp_begin_id
                ]
                if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
                    last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
                    duration = last_timestamp_position * self.time_precision

                yield time_offset, time_offset + duration, tokens

                offset += segment.shape[-1]
                all_tokens.extend(tokens)

            if not options.condition_on_previous_text or temperature > 0.5:
                prompt_reset_since = len(all_tokens)

    def decode_text_tokens(self, tokens):
        text_tokens = [
            self.ids_to_tokens[token] for token in tokens if token < self.eot_id
        ]

        return self.decoder.decode(text_tokens)

    def generate_with_fallback(self, segment, prompt, options):
        features = self.get_input(segment)
        result = None
        final_temperature = None

        for temperature in options.temperatures:
            if temperature > 0:
                kwargs = {
                    "beam_size": 1,
                    "num_hypotheses": options.best_of,
                    "sampling_topk": 0,
                    "sampling_temperature": temperature,
                }
            else:
                kwargs = {
                    "beam_size": options.beam_size,
                    "patience": options.patience,
                }

            final_temperature = temperature
            result = self.model.generate(
                features,
                [prompt],
                max_length=self.max_length,
                return_scores=True,
                return_no_speech_prob=True,
                **kwargs,
            )[0]

            tokens = result.sequences_ids[0]
            text = self.decode_text_tokens(tokens)
            compression_ratio = get_compression_ratio(text)

            if (
                compression_ratio <= options.compression_ratio_threshold
                and result.scores[0] >= options.log_prob_threshold
            ):
                break

        return result, final_temperature

    def get_prompt(self, language, previous_tokens):
        prompt = []

        if previous_tokens:
            prompt.append(self.tokens_to_ids["<|startofprev|>"])
            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

        prompt += [
            self.tokens_to_ids["<|startoftranscript|>"],
            self.tokens_to_ids["<|%s|>" % language],
            self.tokens_to_ids["<|transcribe|>"],
        ]

        return prompt

    def get_segment(self, features, offset=0):
        if offset > 0:
            features = features[:, offset:]

        num_frames = features.shape[-1]
        required_num_frames = self.feature_extractor.nb_max_frames

        if num_frames > required_num_frames:
            features = features[:, :required_num_frames]
        elif num_frames < required_num_frames:
            pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
            features = np.pad(features, pad_widths)

        features = np.ascontiguousarray(features)
        return features

    def get_input(self, segment):
        segment = np.expand_dims(segment, 0)
        segment = ctranslate2.StorageView.from_array(segment)
        return segment


def get_compression_ratio(text):
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
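Note that `transcribe()` returns the segments as a generator, so decoding only runs while it is iterated. A short sketch (the model path and audio file are placeholders):

```python
from faster_whisper import WhisperModel

model = WhisperModel("whisper-medium-ct2/", device="cpu", compute_type="int8")

# Passing a language skips detection, so language_probability is reported as 1.
segments, info = model.transcribe("audio.mp3", language="en", beam_size=5)

# The generator is lazy; consuming it runs the full transcription.
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```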
1  requirements.conversion.txt  Normal file
@@ -0,0 +1 @@
transformers[torch]>=4.23
3  requirements.txt  Normal file
@@ -0,0 +1,3 @@
av==10.*
ctranslate2>=3.5,<4
tokenizers==0.13.*
28  setup.py  Normal file
@@ -0,0 +1,28 @@
import os

from setuptools import find_packages, setup


def get_requirements(path):
    with open(path, encoding="utf-8") as requirements:
        return [requirement.strip() for requirement in requirements]


base_dir = os.path.dirname(os.path.abspath(__file__))
install_requires = get_requirements(os.path.join(base_dir, "requirements.txt"))
conversion_requires = get_requirements(
    os.path.join(base_dir, "requirements.conversion.txt")
)

setup(
    name="faster-whisper",
    version="0.1.0",
    description="Faster Whisper transcription with CTranslate2",
    author="Guillaume Klein",
    python_requires=">=3.7",
    install_requires=install_requires,
    extras_require={
        "conversion": conversion_requires,
    },
    packages=find_packages(),
)