Make threshold parameters optional

This commit is contained in:
Guillaume Klein
2023-02-27 11:32:03 +01:00
parent f0add58bdc
commit 528aa3e784

View File

@@ -107,9 +107,9 @@ class WhisperModel:
0.8,
1.0,
],
compression_ratio_threshold: float = 2.4,
log_prob_threshold: float = -1.0,
no_speech_threshold: float = 0.6,
compression_ratio_threshold: Optional[float] = 2.4,
log_prob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
without_timestamps: bool = False,
@@ -241,12 +241,21 @@ class WhisperModel:
segment, prompt, options
)
if (
result.no_speech_prob > options.no_speech_threshold
and avg_log_prob < options.log_prob_threshold
):
offset += segment.shape[-1]
continue
if options.no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > options.no_speech_threshold
if (
options.log_prob_threshold is not None
and avg_log_prob > options.log_prob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
# fast-forward to the next segment boundary
offset += segment.shape[-1]
continue
tokens = result.sequences_ids[0]
@@ -350,10 +359,21 @@ class WhisperModel:
text = self.decode_text_tokens(tokens).strip()
compression_ratio = get_compression_ratio(text)
needs_fallback = False
if (
compression_ratio <= options.compression_ratio_threshold
and avg_log_prob >= options.log_prob_threshold
options.compression_ratio_threshold is not None
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
options.log_prob_threshold is not None
and avg_log_prob < options.log_prob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
return result, avg_log_prob, final_temperature