Return result with best log prob when all temperature fallbacks failed (#356)

* Resolve Inference Selection Bug

* Refactor for better readability

* Filter out results with compression_ratio

* Refactor to avoid variable repetition

* Fix incorrect index and perform minor refactoring

* Remove final_temperature variable
This commit is contained in:
KH
2023-07-20 23:13:11 +09:00
committed by GitHub
parent 687db319e0
commit e786e26f75

View File

@@ -578,10 +578,9 @@ class WhisperModel:
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
-        result = None
-        avg_logprob = None
-        final_temperature = None
-        compression_ratio = None
+        decode_result = None
+        all_results = []
+        below_cr_threshold_results = []

         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -601,7 +600,6 @@ class WhisperModel:
                 "patience": options.patience,
             }

-            final_temperature = temperature
             result = self.model.generate(
                 encoder_output,
                 [prompt],
@@ -625,12 +623,18 @@ class WhisperModel:
text = tokenizer.decode(tokens).strip() text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text) compression_ratio = get_compression_ratio(text)
decode_result = (
result,
avg_logprob,
temperature,
compression_ratio,
)
all_results.append(decode_result)
needs_fallback = False needs_fallback = False
if ( if options.compression_ratio_threshold is not None:
options.compression_ratio_threshold is not None if compression_ratio > options.compression_ratio_threshold:
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive needs_fallback = True # too repetitive
self.logger.debug( self.logger.debug(
@@ -639,6 +643,8 @@ class WhisperModel:
                         compression_ratio,
                         options.compression_ratio_threshold,
                     )
+            else:
+                below_cr_threshold_results.append(decode_result)

             if (
                 options.log_prob_threshold is not None
@@ -661,8 +667,13 @@ class WhisperModel:
             if not needs_fallback:
                 break
+        else:
+            # all failed, select the result with the highest average log probability
+            decode_result = max(
+                below_cr_threshold_results or all_results, key=lambda x: x[1]
+            )

-        return result, avg_logprob, final_temperature, compression_ratio
+        return decode_result

     def get_prompt(
         self,