Add length_penalty parameter and correctly compute the avg log prob
@@ -27,6 +27,7 @@ class TranscriptionOptions(
             "beam_size",
             "best_of",
             "patience",
+            "length_penalty",
             "log_prob_threshold",
             "no_speech_threshold",
             "compression_ratio_threshold",
@@ -95,6 +96,7 @@ class WhisperModel:
         beam_size=5,
         best_of=5,
         patience=1,
+        length_penalty=1,
         temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
         compression_ratio_threshold=2.4,
         log_prob_threshold=-1.0,
@@ -114,6 +116,7 @@ class WhisperModel:
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
+          length_penalty: Exponential length penalty constant.
           temperature: Temperature for sampling. It can be a tuple of temperatures,
             which will be successively used upon failures according to either
             `compression_ratio_threshold` or `log_prob_threshold`.
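
For reference, the length penalty here is the exponential normalization that the recovery code later in this diff inverts: the decoder's returned score is the cumulative log probability divided by length**length_penalty. A minimal sketch of that scoring rule, with a hypothetical function name of my own choosing:

    # Scoring rule implied by this diff. With length_penalty == 1 the score
    # is a plain per-token average; values > 1 favor longer hypotheses,
    # values < 1 favor shorter ones.
    def length_penalized_score(cum_log_prob: float, length: int, length_penalty: float) -> float:
        return cum_log_prob / (length ** length_penalty)

    # E.g. a 10-token hypothesis with cumulative log prob -5.0:
    # penalty 1.0 -> -0.5, penalty 2.0 -> -0.05 (longer sequences penalized less).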
@@ -162,6 +165,7 @@ class WhisperModel:
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
+            length_penalty=length_penalty,
             log_prob_threshold=log_prob_threshold,
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
@@ -224,11 +228,13 @@ class WhisperModel:
                 without_timestamps=options.without_timestamps,
             )

-            result, temperature = self.generate_with_fallback(segment, prompt, options)
+            result, avg_log_prob, temperature = self.generate_with_fallback(
+                segment, prompt, options
+            )

             if (
                 result.no_speech_prob > options.no_speech_threshold
-                and result.scores[0] < options.log_prob_threshold
+                and avg_log_prob < options.log_prob_threshold
             ):
                 offset += segment.shape[-1]
                 continue
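
The behavioral point of this hunk: a window is skipped as silence only when both tests agree, and the second test now compares the true average log prob rather than the length-penalized score, so changing length_penalty no longer shifts the silence decision. A standalone sketch of the gate; the no_speech default of 0.6 is an assumption (only the -1.0 default appears in this diff):

    def should_skip_segment(
        no_speech_prob: float,
        avg_log_prob: float,
        no_speech_threshold: float = 0.6,   # assumed default, not shown in this diff
        log_prob_threshold: float = -1.0,   # default shown in the hunks above
    ) -> bool:
        # Skip only if the model thinks the window is non-speech AND the
        # decoded text is low-confidence; a confident decode overrides.
        return no_speech_prob > no_speech_threshold and avg_log_prob < log_prob_threshold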
@@ -297,6 +303,7 @@ class WhisperModel:
     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
         result = None
+        avg_log_prob = None
         final_temperature = None
         max_length = min(self.max_length, 2 * (self.max_length - len(prompt)))

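
For context, generate_with_fallback re-decodes the same segment at successively higher temperatures until the quality checks pass; that loop is what the next hunk modifies. A self-contained sketch of the control flow under that assumption (thresholds mirror the diff's defaults; DecodeResult and the decode callable are hypothetical):

    from typing import Callable, NamedTuple, Sequence

    class DecodeResult(NamedTuple):
        avg_log_prob: float
        compression_ratio: float

    def decode_with_fallback(
        decode: Callable[[float], DecodeResult],
        temperatures: Sequence[float] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        compression_ratio_threshold: float = 2.4,
        log_prob_threshold: float = -1.0,
    ) -> DecodeResult:
        result = None
        for temperature in temperatures:
            result = decode(temperature)
            # Accept the first decode that is neither too repetitive nor too
            # low-confidence; otherwise retry at the next, higher temperature.
            if (result.compression_ratio <= compression_ratio_threshold
                    and result.avg_log_prob >= log_prob_threshold):
                break
        return result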
@@ -318,23 +325,29 @@ class WhisperModel:
             result = self.model.generate(
                 features,
                 [prompt],
+                length_penalty=options.length_penalty,
                 max_length=max_length,
                 return_scores=True,
                 return_no_speech_prob=True,
                 **kwargs,
             )[0]

+            # Recover the average log prob from the returned score.
+            seq_len = len(result.sequences_ids[0])
+            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_log_prob = cum_log_prob / (seq_len + 1)
+
             tokens = result.sequences_ids[0]
             text = self.decode_text_tokens(tokens)
             compression_ratio = get_compression_ratio(text)

             if (
                 compression_ratio <= options.compression_ratio_threshold
-                and result.scores[0] >= options.log_prob_threshold
+                and avg_log_prob >= options.log_prob_threshold
             ):
                 break

-        return result, final_temperature
+        return result, avg_log_prob, final_temperature

     def get_prompt(
         self,
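
The recovery step is the core of the fix: because return_scores=True yields the length-penalized score, multiplying by seq_len**length_penalty undoes the normalization, and dividing by seq_len + 1 gives the per-token average that log_prob_threshold was always meant to compare against. A standalone sketch with a worked value:

    def recover_avg_log_prob(score: float, seq_len: int, length_penalty: float) -> float:
        # Invert score == cum_log_prob / seq_len**length_penalty (the rule
        # implied by the hunk above), then average; the +1 presumably counts
        # the end-of-sequence token that sequences_ids does not include.
        cum_log_prob = score * (seq_len ** length_penalty)
        return cum_log_prob / (seq_len + 1)

    # Worked example: 9 tokens, length_penalty 1.2, penalized score -0.35:
    #   cum_log_prob = -0.35 * 9**1.2 ≈ -4.89
    #   avg_log_prob ≈ -4.89 / 10 ≈ -0.49, comfortably above the -1.0 default.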