faster-whisper/faster_whisper/transcribe.py

import collections
import os
import zlib

import ctranslate2
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor


class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
    pass


class AudioInfo(
    collections.namedtuple("AudioInfo", ("language", "language_probability"))
):
    pass


class TranscriptionOptions(
    collections.namedtuple(
        "TranscriptionOptions",
        (
            "beam_size",
            "best_of",
            "patience",
            "log_prob_threshold",
            "no_speech_threshold",
            "compression_ratio_threshold",
            "condition_on_previous_text",
            "temperatures",
        ),
    )
):
    pass


class WhisperModel:
    def __init__(
        self,
        model_path,
        device="auto",
        compute_type="default",
        cpu_threads=0,
    ):
        """Initializes the Whisper model.

        Args:
          model_path: Path to the converted model.
          device: Device to use for computation ("cpu", "cuda", "auto").
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non-zero value overrides the OMP_NUM_THREADS environment variable.
        """
        self.model = ctranslate2.models.Whisper(
            model_path,
            device=device,
            compute_type=compute_type,
            intra_threads=cpu_threads,
        )

        self.feature_extractor = FeatureExtractor()
        self.decoder = tokenizers.decoders.ByteLevel()

        with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
            self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]

        self.tokens_to_ids = {
            token: i for i, token in enumerate(self.ids_to_tokens)
        }
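        # Special token ids and decoding constants matching the OpenAI Whisper
        # reference implementation: timestamp tokens start right after
        # <|notimestamps|>, each timestamp step represents 0.02 seconds, one
        # encoder output frame covers 2 input mel frames, and the decoder
        # generates at most 448 tokens per window.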
        self.eot_id = self.tokens_to_ids["<|endoftext|>"]
        self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
        self.input_stride = 2
        self.time_precision = 0.02
        self.max_length = 448

    def transcribe(
        self,
        input_file,
        language=None,
        beam_size=5,
        best_of=5,
        patience=1,
        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        compression_ratio_threshold=2.4,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        condition_on_previous_text=True,
    ):
        """Transcribes an input file.

        Arguments:
          input_file: Path to the input file or a file-like object.
          language: The language spoken in the audio. If not set, the language will be
            detected in the first 30 seconds of audio.
          beam_size: Beam size to use for decoding.
          best_of: Number of candidates when sampling with non-zero temperature.
          patience: Beam search patience factor.
          temperature: Temperature for sampling. It can be a tuple of temperatures,
            which will be successively used upon failures according to either
            `compression_ratio_threshold` or `log_prob_threshold`.
          compression_ratio_threshold: If the zlib compression ratio is above this value,
            treat the decoding as failed.
          log_prob_threshold: If the average log probability over sampled tokens is
            below this value, treat the decoding as failed.
          no_speech_threshold: If the no_speech probability is higher than this value AND
            the average log probability over sampled tokens is below `log_prob_threshold`,
            consider the segment as silent.
          condition_on_previous_text: If True, the previous output of the model is provided
            as a prompt for the next window; disabling may make the text inconsistent across
            windows, but the model becomes less prone to getting stuck in a failure loop,
            such as repetition looping or timestamps going out of sync.

        Returns:
          A tuple with:

            - a generator over transcribed segments
            - an instance of AudioInfo
        """
        audio = decode_audio(
            input_file, sampling_rate=self.feature_extractor.sampling_rate
        )

        features = self.feature_extractor(audio)
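        # If no language was given, detect it from the first 30-second window;
        # otherwise assume the provided language with probability 1.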
        if language is None:
            segment = self.get_segment(features)
            input = self.get_input(segment)
            results = self.model.detect_language(input)
            language_token, language_probability = results[0][0]
            language = language_token[2:-2]
        else:
            language_probability = 1

        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            condition_on_previous_text=condition_on_previous_text,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
        )

        segments = self.generate_segments(features, language, options)

        audio_info = AudioInfo(
            language=language,
            language_probability=language_probability,
        )

        return segments, audio_info

    def generate_segments(self, features, language, options):
        tokenized_segments = self.generate_tokenized_segments(
            features, language, options
        )

        for start, end, tokens in tokenized_segments:
            text = self.decode_text_tokens(tokens)

            if not text.strip():
                continue

            yield Segment(
                start=start,
                end=end,
                text=text,
            )

    def generate_tokenized_segments(self, features, language, options):
        num_frames = features.shape[-1]
        offset = 0
        all_tokens = []
        prompt_reset_since = 0
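        # Slide a fixed 30-second window over the features; `offset` advances either
        # to the last timestamp predicted in the window or past the whole window.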
        while offset < num_frames:
            time_offset = offset * self.feature_extractor.time_per_frame
            segment = self.get_segment(features, offset)
            segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame

            previous_tokens = all_tokens[prompt_reset_since:]
            prompt = self.get_prompt(language, previous_tokens)

            result, temperature = self.generate_with_fallback(segment, prompt, options)
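            # Skip the window entirely if the model considers it silent: a high
            # no-speech probability combined with a low average log probability.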
            if (
                result.no_speech_prob > options.no_speech_threshold
                and result.scores[0] < options.log_prob_threshold
            ):
                offset += segment.shape[-1]
                continue

            tokens = result.sequences_ids[0]
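            # Segment boundaries are marked by two consecutive timestamp tokens:
            # the end time of one segment followed by the start time of the next.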
            consecutive_timestamps = [
                i
                for i in range(len(tokens))
                if i > 0
                and tokens[i] >= self.timestamp_begin_id
                and tokens[i - 1] >= self.timestamp_begin_id
            ]

            if len(consecutive_timestamps) > 0:
                last_slice = 0
                for i, current_slice in enumerate(consecutive_timestamps):
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0] - self.timestamp_begin_id
                    )
                    end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
                    start_time = (
                        time_offset + start_timestamp_position * self.time_precision
                    )
                    end_time = (
                        time_offset + end_timestamp_position * self.time_precision
                    )
                    last_in_window = i + 1 == len(consecutive_timestamps)

                    # Include the last timestamp so that all tokens are included in a segment.
                    if last_in_window:
                        sliced_tokens.append(tokens[current_slice])

                    yield start_time, end_time, sliced_tokens
                    last_slice = current_slice

                last_timestamp_position = (
                    tokens[last_slice - 1] - self.timestamp_begin_id
                )
                offset += last_timestamp_position * self.input_stride
                all_tokens.extend(tokens[: last_slice + 1])

            else:
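                # No consecutive timestamp pair was predicted: emit the whole window
                # as a single segment, bounding its duration with the last timestamp
                # token when one is present.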
                duration = segment_duration
                timestamps = [
                    token for token in tokens if token >= self.timestamp_begin_id
                ]
                if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
                    last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
                    duration = last_timestamp_position * self.time_precision

                yield time_offset, time_offset + duration, tokens

                offset += segment.shape[-1]
                all_tokens.extend(tokens)
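            # Reset the conditioning prompt when it is disabled or when a high
            # temperature was needed, since the previous text is then unreliable.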
            if not options.condition_on_previous_text or temperature > 0.5:
                prompt_reset_since = len(all_tokens)

    def decode_text_tokens(self, tokens):
        text_tokens = [
            self.ids_to_tokens[token] for token in tokens if token < self.eot_id
        ]

        return self.decoder.decode(text_tokens)

    def generate_with_fallback(self, segment, prompt, options):
        features = self.get_input(segment)
        result = None
        final_temperature = None
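        # Try each temperature in turn: random sampling over `best_of` hypotheses
        # when the temperature is above zero, beam search otherwise. Stop at the
        # first result that passes both the compression-ratio and log-probability
        # checks.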
        for temperature in options.temperatures:
            if temperature > 0:
                kwargs = {
                    "beam_size": 1,
                    "num_hypotheses": options.best_of,
                    "sampling_topk": 0,
                    "sampling_temperature": temperature,
                }
            else:
                kwargs = {
                    "beam_size": options.beam_size,
                    "patience": options.patience,
                }

            final_temperature = temperature
            result = self.model.generate(
                features,
                [prompt],
                max_length=self.max_length,
                return_scores=True,
                return_no_speech_prob=True,
                **kwargs,
            )[0]

            tokens = result.sequences_ids[0]
            text = self.decode_text_tokens(tokens)
            compression_ratio = get_compression_ratio(text)

            if (
                compression_ratio <= options.compression_ratio_threshold
                and result.scores[0] >= options.log_prob_threshold
            ):
                break

        return result, final_temperature

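    # The decoder prompt is: an optional <|startofprev|> followed by (truncated)
    # tokens from previous windows, then <|startoftranscript|>, the language
    # token, and the <|transcribe|> task token.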
    def get_prompt(self, language, previous_tokens):
        prompt = []

        if previous_tokens:
            prompt.append(self.tokens_to_ids["<|startofprev|>"])
            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

        prompt += [
            self.tokens_to_ids["<|startoftranscript|>"],
            self.tokens_to_ids["<|%s|>" % language],
            self.tokens_to_ids["<|transcribe|>"],
        ]

        return prompt
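    # get_segment returns one window of features starting at `offset`, padded or
    # truncated to the fixed 30-second input length expected by the encoder.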
    def get_segment(self, features, offset=0):
        if offset > 0:
            features = features[:, offset:]

        num_frames = features.shape[-1]
        required_num_frames = self.feature_extractor.nb_max_frames

        if num_frames > required_num_frames:
            features = features[:, :required_num_frames]
        elif num_frames < required_num_frames:
            pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
            features = np.pad(features, pad_widths)

        features = np.ascontiguousarray(features)
        return features

    def get_input(self, segment):
        segment = np.expand_dims(segment, 0)
        segment = ctranslate2.StorageView.from_array(segment)
        return segment

def get_compression_ratio(text):
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
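

# Example usage (illustrative sketch only; the model directory and audio file
# below are placeholders, not paths shipped with this module).
if __name__ == "__main__":
    model = WhisperModel("path/to/converted-model", device="auto")
    segments, audio_info = model.transcribe("audio.wav", beam_size=5)

    print("Detected language:", audio_info.language)
    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))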