Make threshold parameters optional

This commit is contained in:
Guillaume Klein
2023-02-27 11:32:03 +01:00
parent f0add58bdc
commit 528aa3e784

View File

@@ -107,9 +107,9 @@ class WhisperModel:
     0.8,
     1.0,
 ],
-compression_ratio_threshold: float = 2.4,
-log_prob_threshold: float = -1.0,
-no_speech_threshold: float = 0.6,
+compression_ratio_threshold: Optional[float] = 2.4,
+log_prob_threshold: Optional[float] = -1.0,
+no_speech_threshold: Optional[float] = 0.6,
 condition_on_previous_text: bool = True,
 initial_prompt: Optional[str] = None,
 without_timestamps: bool = False,
@@ -241,12 +241,21 @@ class WhisperModel:
     segment, prompt, options
 )
-if (
-    result.no_speech_prob > options.no_speech_threshold
-    and avg_log_prob < options.log_prob_threshold
-):
-    offset += segment.shape[-1]
-    continue
+if options.no_speech_threshold is not None:
+    # no voice activity check
+    should_skip = result.no_speech_prob > options.no_speech_threshold
+    if (
+        options.log_prob_threshold is not None
+        and avg_log_prob > options.log_prob_threshold
+    ):
+        # don't skip if the logprob is high enough, despite the no_speech_prob
+        should_skip = False
+    if should_skip:
+        # fast-forward to the next segment boundary
+        offset += segment.shape[-1]
+        continue
 tokens = result.sequences_ids[0]
@@ -350,10 +359,21 @@ class WhisperModel:
     text = self.decode_text_tokens(tokens).strip()
     compression_ratio = get_compression_ratio(text)
+    needs_fallback = False
     if (
-        compression_ratio <= options.compression_ratio_threshold
-        and avg_log_prob >= options.log_prob_threshold
+        options.compression_ratio_threshold is not None
+        and compression_ratio > options.compression_ratio_threshold
     ):
+        needs_fallback = True  # too repetitive
+    if (
+        options.log_prob_threshold is not None
+        and avg_log_prob < options.log_prob_threshold
+    ):
+        needs_fallback = True  # average log probability is too low
+    if not needs_fallback:
         break
 return result, avg_log_prob, final_temperature