Add length_penalty parameter and correctly compute the avg log prob

This commit is contained in:
Guillaume Klein
2023-02-22 10:27:38 +01:00
parent f5c9f15c2c
commit e47e00910a

View File

@@ -27,6 +27,7 @@ class TranscriptionOptions(
"beam_size",
"best_of",
"patience",
"length_penalty",
"log_prob_threshold",
"no_speech_threshold",
"compression_ratio_threshold",
@@ -95,6 +96,7 @@ class WhisperModel:
beam_size=5,
best_of=5,
patience=1,
length_penalty=1,
temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
@@ -114,6 +116,7 @@ class WhisperModel:
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
length_penalty: Exponential length penalty constant.
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
@@ -162,6 +165,7 @@ class WhisperModel:
beam_size=beam_size,
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
@@ -224,11 +228,13 @@ class WhisperModel:
without_timestamps=options.without_timestamps,
)
result, temperature = self.generate_with_fallback(segment, prompt, options)
result, avg_log_prob, temperature = self.generate_with_fallback(
segment, prompt, options
)
if (
result.no_speech_prob > options.no_speech_threshold
and result.scores[0] < options.log_prob_threshold
and avg_log_prob < options.log_prob_threshold
):
offset += segment.shape[-1]
continue
@@ -297,6 +303,7 @@ class WhisperModel:
def generate_with_fallback(self, segment, prompt, options):
features = self.get_input(segment)
result = None
avg_log_prob = None
final_temperature = None
max_length = min(self.max_length, 2 * (self.max_length - len(prompt)))
@@ -318,23 +325,29 @@ class WhisperModel:
result = self.model.generate(
features,
[prompt],
length_penalty=options.length_penalty,
max_length=max_length,
return_scores=True,
return_no_speech_prob=True,
**kwargs,
)[0]
# Recover the average log prob from the returned score.
seq_len = len(result.sequences_ids[0])
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
avg_log_prob = cum_log_prob / (seq_len + 1)
tokens = result.sequences_ids[0]
text = self.decode_text_tokens(tokens)
compression_ratio = get_compression_ratio(text)
if (
compression_ratio <= options.compression_ratio_threshold
and result.scores[0] >= options.log_prob_threshold
and avg_log_prob >= options.log_prob_threshold
):
break
return result, final_temperature
return result, avg_log_prob, final_temperature
def get_prompt(
self,