From e47e00910a0c985be0f63bb7d6ba5f11fbaa9d39 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Wed, 22 Feb 2023 10:27:38 +0100
Subject: [PATCH] Add length_penalty parameter and correctly compute the avg log prob

---
 faster_whisper/transcribe.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 77c8943..1dd9d4a 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -27,6 +27,7 @@ class TranscriptionOptions(
             "beam_size",
             "best_of",
             "patience",
+            "length_penalty",
             "log_prob_threshold",
             "no_speech_threshold",
             "compression_ratio_threshold",
@@ -95,6 +96,7 @@ class WhisperModel:
         beam_size=5,
         best_of=5,
         patience=1,
+        length_penalty=1,
         temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
         compression_ratio_threshold=2.4,
         log_prob_threshold=-1.0,
@@ -114,6 +116,7 @@ class WhisperModel:
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
+          length_penalty: Exponential length penalty constant.
           temperature: Temperature for sampling. It can be a tuple of temperatures,
             which will be successively used upon failures according to either
             `compression_ratio_threshold` or `logprob_threshold`.
@@ -162,6 +165,7 @@ class WhisperModel:
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
+            length_penalty=length_penalty,
             log_prob_threshold=log_prob_threshold,
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
@@ -224,11 +228,13 @@ class WhisperModel:
                 without_timestamps=options.without_timestamps,
             )
 
-            result, temperature = self.generate_with_fallback(segment, prompt, options)
+            result, avg_log_prob, temperature = self.generate_with_fallback(
+                segment, prompt, options
+            )
 
             if (
                 result.no_speech_prob > options.no_speech_threshold
-                and result.scores[0] < options.log_prob_threshold
+                and avg_log_prob < options.log_prob_threshold
             ):
                 offset += segment.shape[-1]
                 continue
@@ -297,6 +303,7 @@ class WhisperModel:
     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
         result = None
+        avg_log_prob = None
         final_temperature = None
         max_length = min(self.max_length, 2 * (self.max_length - len(prompt)))
 
@@ -318,23 +325,29 @@ class WhisperModel:
             result = self.model.generate(
                 features,
                 [prompt],
+                length_penalty=options.length_penalty,
                 max_length=max_length,
                 return_scores=True,
                 return_no_speech_prob=True,
                 **kwargs,
             )[0]
 
+            # Recover the average log prob from the returned score.
+            seq_len = len(result.sequences_ids[0])
+            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_log_prob = cum_log_prob / (seq_len + 1)
+
             tokens = result.sequences_ids[0]
             text = self.decode_text_tokens(tokens)
             compression_ratio = get_compression_ratio(text)
 
             if (
                 compression_ratio <= options.compression_ratio_threshold
-                and result.scores[0] >= options.log_prob_threshold
+                and avg_log_prob >= options.log_prob_threshold
             ):
                 break
 
-        return result, final_temperature
+        return result, avg_log_prob, final_temperature
 
     def get_prompt(
         self,
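
Note (reviewer sketch, not part of the patch): judging from the recovery arithmetic in the last hunk, the score returned by generate(..., return_scores=True) appears to be the cumulative log probability normalized by len(sequence) ** length_penalty, and the extra "+ 1" in the denominator presumably accounts for the end-of-sequence token that is not included in sequences_ids. A minimal standalone illustration of the same computation, with hypothetical names:

    def recover_avg_log_prob(score, seq_len, length_penalty=1.0):
        # Undo the length-penalty normalization applied to the beam score
        # to get back the cumulative log probability of the hypothesis.
        cum_log_prob = score * (seq_len**length_penalty)
        # Average over the tokens, counting the end-of-sequence token that
        # is assumed not to appear in sequences_ids (hence the + 1).
        return cum_log_prob / (seq_len + 1)

    # Example: a 10-token hypothesis with a penalized score of -0.25 and
    # length_penalty=1 gives an average log prob of about -0.227.
    print(recover_avg_log_prob(-0.25, 10, length_penalty=1.0))

This recovered average log prob is then what gets compared against log_prob_threshold, instead of the raw penalized score used previously.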