Add length_penalty parameter and correctly compute the avg log prob

This commit is contained in:
Guillaume Klein
2023-02-22 10:27:38 +01:00
parent f5c9f15c2c
commit e47e00910a

View File

@@ -27,6 +27,7 @@ class TranscriptionOptions(
         "beam_size",
         "best_of",
         "patience",
+        "length_penalty",
         "log_prob_threshold",
         "no_speech_threshold",
         "compression_ratio_threshold",
@@ -95,6 +96,7 @@ class WhisperModel:
         beam_size=5,
         best_of=5,
         patience=1,
+        length_penalty=1,
         temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
         compression_ratio_threshold=2.4,
         log_prob_threshold=-1.0,
@@ -114,6 +116,7 @@ class WhisperModel:
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
+          length_penalty: Exponential length penalty constant.
           temperature: Temperature for sampling. It can be a tuple of temperatures,
             which will be successively used upon failures according to either
             `compression_ratio_threshold` or `logprob_threshold`.
@@ -162,6 +165,7 @@ class WhisperModel:
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
+            length_penalty=length_penalty,
             log_prob_threshold=log_prob_threshold,
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
@@ -224,11 +228,13 @@ class WhisperModel:
                 without_timestamps=options.without_timestamps,
             )

-            result, temperature = self.generate_with_fallback(segment, prompt, options)
+            result, avg_log_prob, temperature = self.generate_with_fallback(
+                segment, prompt, options
+            )

             if (
                 result.no_speech_prob > options.no_speech_threshold
-                and result.scores[0] < options.log_prob_threshold
+                and avg_log_prob < options.log_prob_threshold
             ):
                 offset += segment.shape[-1]
                 continue
@@ -297,6 +303,7 @@ class WhisperModel:
     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
         result = None
+        avg_log_prob = None
         final_temperature = None

         max_length = min(self.max_length, 2 * (self.max_length - len(prompt)))
@@ -318,23 +325,29 @@ class WhisperModel:
             result = self.model.generate(
                 features,
                 [prompt],
+                length_penalty=options.length_penalty,
                 max_length=max_length,
                 return_scores=True,
                 return_no_speech_prob=True,
                 **kwargs,
             )[0]

+            # Recover the average log prob from the returned score.
+            seq_len = len(result.sequences_ids[0])
+            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_log_prob = cum_log_prob / (seq_len + 1)
+
             tokens = result.sequences_ids[0]
             text = self.decode_text_tokens(tokens)
             compression_ratio = get_compression_ratio(text)

             if (
                 compression_ratio <= options.compression_ratio_threshold
-                and result.scores[0] >= options.log_prob_threshold
+                and avg_log_prob >= options.log_prob_threshold
             ):
                 break

-        return result, final_temperature
+        return result, avg_log_prob, final_temperature

     def get_prompt(
         self,