Merge remote-tracking branch 'upstream/master' into prompt

This commit is contained in:
2023-12-25 17:56:50 +08:00
10 changed files with 459 additions and 99 deletions

View File

@@ -1,8 +1,10 @@
import itertools
import json
import logging
import os
import zlib
from inspect import signature
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
import ctranslate2
@@ -11,7 +13,7 @@ import tokenizers
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
@@ -47,10 +49,13 @@ class TranscriptionOptions(NamedTuple):
best_of: int
patience: float
length_penalty: float
repetition_penalty: float
no_repeat_ngram_size: int
log_prob_threshold: Optional[float]
no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool
prompt_reset_on_temperature: float
temperatures: List[float]
initial_prompt: Optional[Union[str, Iterable[int]]]
prefix: Optional[str]
@@ -67,6 +72,7 @@ class TranscriptionInfo(NamedTuple):
language: str
language_probability: float
duration: float
duration_after_vad: float
all_language_probs: Optional[List[Tuple[str, float]]]
transcription_options: TranscriptionOptions
vad_options: VadOptions
@@ -88,8 +94,9 @@ class WhisperModel:
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, or large-v2) or a path to a converted
model directory. When a size is configured, the converted model is downloaded
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub.
device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use.
@@ -137,7 +144,8 @@ class WhisperModel:
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
)
self.feature_extractor = FeatureExtractor()
self.feat_kwargs = self._get_feature_kwargs(model_path)
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
self.num_samples_per_token = self.feature_extractor.hop_length * 2
self.frames_per_second = (
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
@@ -149,6 +157,27 @@ class WhisperModel:
self.time_precision = 0.02
self.max_length = 448
@property
def supported_languages(self) -> List[str]:
"""The languages supported by the model."""
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
def _get_feature_kwargs(self, model_path) -> dict:
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
config = {}
if os.path.isfile(preprocessor_config_file):
try:
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
config = {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning(
"Could not load preprocessor_config.json: %s", str(e)
)
return config
def transcribe(
self,
audio: Union[str, BinaryIO, np.ndarray],
@@ -158,6 +187,8 @@ class WhisperModel:
best_of: int = 5,
patience: float = 1,
length_penalty: float = 1,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0,
0.2,
@@ -170,6 +201,7 @@ class WhisperModel:
log_prob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
prefix: Optional[str] = None,
suppress_blank: bool = True,
@@ -194,6 +226,9 @@ class WhisperModel:
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
@@ -208,6 +243,8 @@ class WhisperModel:
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync.
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True.
initial_prompt: Optional text string or iterable of token ids to provide as a
prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
@@ -240,6 +277,7 @@ class WhisperModel:
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
duration_after_vad = duration
self.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
@@ -252,10 +290,11 @@ class WhisperModel:
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
duration_after_vad = audio.shape[0] / sampling_rate
self.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - (audio.shape[0] / sampling_rate)),
format_timestamp(duration - duration_after_vad),
)
if self.logger.isEnabledFor(logging.DEBUG):
@@ -300,6 +339,13 @@ class WhisperModel:
language_probability,
)
else:
if not self.model.is_multilingual and language != "en":
self.logger.warning(
"The current model is English-only but the language parameter is set to '%s'; "
"using 'en' instead." % language
)
language = "en"
language_probability = 1
tokenizer = Tokenizer(
@@ -314,10 +360,13 @@ class WhisperModel:
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
condition_on_previous_text=condition_on_previous_text,
prompt_reset_on_temperature=prompt_reset_on_temperature,
temperatures=(
temperature if isinstance(temperature, (list, tuple)) else [temperature]
),
@@ -341,6 +390,7 @@ class WhisperModel:
language=language,
language_probability=language_probability,
duration=duration,
duration_after_vad=duration_after_vad,
transcription_options=options,
vad_options=vad_parameters,
all_language_probs=all_language_probs,
@@ -370,6 +420,7 @@ class WhisperModel:
else:
all_tokens.extend(options.initial_prompt)
last_speech_timestamp = 0.0
while seek < content_frames:
time_offset = seek * self.feature_extractor.time_per_frame
segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
@@ -391,7 +442,7 @@ class WhisperModel:
prefix=options.prefix if seek == 0 else None,
)
if encoder_output is None:
if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment)
(
@@ -511,12 +562,14 @@ class WhisperModel:
segment_size,
options.prepend_punctuations,
options.append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
@@ -525,8 +578,6 @@ class WhisperModel:
if seek_shift > 0:
seek = previous_seek + seek_shift
encoder_output = None
for segment in current_segments:
tokens = segment["tokens"]
text = tokenizer.decode(tokens)
@@ -563,7 +614,17 @@ class WhisperModel:
),
)
if not options.condition_on_previous_text or temperature > 0.5:
if (
not options.condition_on_previous_text
or temperature > options.prompt_reset_on_temperature
):
if options.condition_on_previous_text:
self.logger.debug(
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
temperature,
options.prompt_reset_on_temperature,
)
prompt_reset_since = len(all_tokens)
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
@@ -583,10 +644,9 @@ class WhisperModel:
tokenizer: Tokenizer,
options: TranscriptionOptions,
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
result = None
avg_logprob = None
final_temperature = None
compression_ratio = None
decode_result = None
all_results = []
below_cr_threshold_results = []
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
@@ -606,11 +666,12 @@ class WhisperModel:
"patience": options.patience,
}
final_temperature = temperature
result = self.model.generate(
encoder_output,
[prompt],
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
max_length=self.max_length,
return_scores=True,
return_no_speech_prob=True,
@@ -630,20 +691,28 @@ class WhisperModel:
text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text)
decode_result = (
result,
avg_logprob,
temperature,
compression_ratio,
)
all_results.append(decode_result)
needs_fallback = False
if (
options.compression_ratio_threshold is not None
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if options.compression_ratio_threshold is not None:
if compression_ratio > options.compression_ratio_threshold:
needs_fallback = True # too repetitive
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
else:
below_cr_threshold_results.append(decode_result)
if (
options.log_prob_threshold is not None
@@ -658,10 +727,28 @@ class WhisperModel:
options.log_prob_threshold,
)
if (
options.no_speech_threshold is not None
and result.no_speech_prob > options.no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
else:
# all failed, select the result with the highest average log probability
decode_result = max(
below_cr_threshold_results or all_results, key=lambda x: x[1]
)
# to pass final temperature for prompt_reset_on_temperature
decode_result = (
decode_result[0],
decode_result[1],
temperature,
decode_result[3],
)
return result, avg_logprob, final_temperature, compression_ratio
return decode_result
def get_prompt(
self,
@@ -685,6 +772,8 @@ class WhisperModel:
prefix_tokens = tokenizer.encode(" " + prefix.strip())
if len(prefix_tokens) >= self.max_length // 2:
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
if not without_timestamps:
prompt.append(tokenizer.timestamp_begin)
prompt.extend(prefix_tokens)
return prompt
@@ -697,7 +786,8 @@ class WhisperModel:
num_frames: int,
prepend_punctuations: str,
append_punctuations: str,
):
last_speech_timestamp: float,
) -> None:
if len(segments) == 0:
return
@@ -710,6 +800,24 @@ class WhisperModel:
alignment = self.find_alignment(
tokenizer, text_tokens, encoder_output, num_frames
)
word_durations = np.array([word["end"] - word["start"] for word in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(word_durations) > 0:
sentence_end_marks = ".。!?"
# ensure words at sentence boundaries
# are not longer than twice the median word duration.
for i in range(1, len(alignment)):
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
if alignment[i]["word"] in sentence_end_marks:
alignment[i]["end"] = alignment[i]["start"] + max_duration
elif alignment[i - 1]["word"] in sentence_end_marks:
alignment[i]["start"] = alignment[i]["end"] - max_duration
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = (
@@ -740,10 +848,51 @@ class WhisperModel:
saved_tokens += len(timing["tokens"])
word_index += 1
# hack: truncate long words at segment boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(words) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]
# ensure the first and second word after a pause is not longer than
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
or (
len(words) > 1
and words[1]["end"] - words[0]["start"] > max_duration * 2
)
):
if (
len(words) > 1
and words[1]["end"] - words[1]["start"] > max_duration
):
boundary = max(
words[1]["end"] / 2, words[1]["end"] - max_duration
)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
if (
segment["start"] < words[0]["end"]
and segment["start"] - 0.5 > words[0]["start"]
):
words[0]["start"] = max(
0, min(words[0]["end"] - median_duration, segment["start"])
)
else:
segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
if (
segment["end"] > words[-1]["start"]
and segment["end"] + 0.5 < words[-1]["end"]
):
words[-1]["end"] = max(
words[-1]["start"] + median_duration, segment["end"]
)
else:
segment["end"] = words[-1]["end"]
last_speech_timestamp = segment["end"]
segment["words"] = words
@@ -775,6 +924,13 @@ class WhisperModel:
words, word_tokens = tokenizer.split_to_word_tokens(
text_tokens + [tokenizer.eot]
)
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
if len(word_boundaries) <= 1:
return []
@@ -788,22 +944,6 @@ class WhisperModel:
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# hack: ensure the first and second word is not longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if (
len(word_durations) >= 1
and end_times[0] - start_times[0] > max_duration
):
start_times[0] = max(0, end_times[0] - max_duration)
return [
dict(
word=word, tokens=tokens, start=start, end=end, probability=probability
@@ -860,7 +1000,10 @@ def get_compression_ratio(text: str) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def get_suppressed_tokens(tokenizer, suppress_tokens):
def get_suppressed_tokens(
tokenizer: Tokenizer,
suppress_tokens: Optional[List[int]],
) -> Optional[List[int]]:
if not suppress_tokens or -1 in suppress_tokens:
return suppress_tokens
@@ -881,7 +1024,7 @@ def get_suppressed_tokens(tokenizer, suppress_tokens):
return sorted(set(suppress_tokens))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1