Add length_penalty parameter and correctly compute the avg log prob
@@ -27,6 +27,7 @@ class TranscriptionOptions(
             "beam_size",
             "best_of",
             "patience",
+            "length_penalty",
             "log_prob_threshold",
             "no_speech_threshold",
             "compression_ratio_threshold",
@@ -95,6 +96,7 @@ class WhisperModel:
         beam_size=5,
         best_of=5,
         patience=1,
+        length_penalty=1,
         temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
         compression_ratio_threshold=2.4,
         log_prob_threshold=-1.0,
@@ -114,6 +116,7 @@ class WhisperModel:
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
+          length_penalty: Exponential length penalty constant.
           temperature: Temperature for sampling. It can be a tuple of temperatures,
             which will be successively used upon failures according to either
             `compression_ratio_threshold` or `logprob_threshold`.
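
For reference, a hypothetical call site for the new option. The model path and audio file below are placeholders, and the return value is assumed to be the segment iterator:

    model = WhisperModel("/path/to/ct2-model")  # placeholder model path
    segments = model.transcribe(
        "audio.wav",  # placeholder input
        beam_size=5,
        length_penalty=1,  # new: exponential length penalty constant
    )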
@@ -162,6 +165,7 @@ class WhisperModel:
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
+            length_penalty=length_penalty,
             log_prob_threshold=log_prob_threshold,
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
@@ -224,11 +228,13 @@ class WhisperModel:
                 without_timestamps=options.without_timestamps,
             )
 
-            result, temperature = self.generate_with_fallback(segment, prompt, options)
+            result, avg_log_prob, temperature = self.generate_with_fallback(
+                segment, prompt, options
+            )
 
             if (
                 result.no_speech_prob > options.no_speech_threshold
-                and result.scores[0] < options.log_prob_threshold
+                and avg_log_prob < options.log_prob_threshold
             ):
                 offset += segment.shape[-1]
                 continue
@@ -297,6 +303,7 @@ class WhisperModel:
     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
         result = None
+        avg_log_prob = None
         final_temperature = None
         max_length = min(self.max_length, 2 * (self.max_length - len(prompt)))
 
@@ -318,23 +325,29 @@ class WhisperModel:
             result = self.model.generate(
                 features,
                 [prompt],
+                length_penalty=options.length_penalty,
                 max_length=max_length,
                 return_scores=True,
                 return_no_speech_prob=True,
                 **kwargs,
             )[0]
 
+            # Recover the average log prob from the returned score.
+            seq_len = len(result.sequences_ids[0])
+            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_log_prob = cum_log_prob / (seq_len + 1)
+
             tokens = result.sequences_ids[0]
             text = self.decode_text_tokens(tokens)
             compression_ratio = get_compression_ratio(text)
 
             if (
                 compression_ratio <= options.compression_ratio_threshold
-                and result.scores[0] >= options.log_prob_threshold
+                and avg_log_prob >= options.log_prob_threshold
             ):
                 break
 
-        return result, final_temperature
+        return result, avg_log_prob, final_temperature
 
     def get_prompt(
         self,
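
The recovery step in the last hunk rests on how CTranslate2 normalizes its returned score: the code above implies `score = cum_log_prob / seq_len ** length_penalty`, so multiplying by `seq_len ** length_penalty` undoes the normalization, and dividing by `seq_len + 1` averages over the generated tokens plus the end token (which `sequences_ids` excludes), matching the `avg_logprob` convention in openai/whisper. A minimal standalone sketch of that arithmetic under those assumptions; `recover_avg_log_prob` is an illustrative helper, not library API:

    def recover_avg_log_prob(score, seq_len, length_penalty):
        # Undo the assumed length normalization to get the cumulative log prob...
        cum_log_prob = score * (seq_len**length_penalty)
        # ...then average over seq_len + 1 tokens to count the end token.
        return cum_log_prob / (seq_len + 1)

    # With length_penalty == 1 the score is already a per-token average, so the
    # recovery only folds in the end token: -0.5 * 9 / 10 == -0.45.
    assert abs(recover_avg_log_prob(-0.5, 9, 1) - (-0.45)) < 1e-9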