Return result with best log prob when all temperature fallbacks failed (#356)

* Resolve Inference Selection Bug

* Refactor for better readability

* Filter out results with compression_ratio

* Refactor to avoid variable repetition

* Fix incorrect index and perform minor refactoring

* Remove final_temperature variable
This commit is contained in:
KH
2023-07-20 23:13:11 +09:00
committed by GitHub
parent 687db319e0
commit e786e26f75

View File

@@ -578,10 +578,9 @@ class WhisperModel:
tokenizer: Tokenizer,
options: TranscriptionOptions,
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
result = None
avg_logprob = None
final_temperature = None
compression_ratio = None
decode_result = None
all_results = []
below_cr_threshold_results = []
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
@@ -601,7 +600,6 @@ class WhisperModel:
"patience": options.patience,
}
final_temperature = temperature
result = self.model.generate(
encoder_output,
[prompt],
@@ -625,20 +623,28 @@ class WhisperModel:
text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text)
decode_result = (
result,
avg_logprob,
temperature,
compression_ratio,
)
all_results.append(decode_result)
needs_fallback = False
if (
options.compression_ratio_threshold is not None
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if options.compression_ratio_threshold is not None:
if compression_ratio > options.compression_ratio_threshold:
needs_fallback = True # too repetitive
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
else:
below_cr_threshold_results.append(decode_result)
if (
options.log_prob_threshold is not None
@@ -661,8 +667,13 @@ class WhisperModel:
if not needs_fallback:
break
else:
# all failed, select the result with the highest average log probability
decode_result = max(
below_cr_threshold_results or all_results, key=lambda x: x[1]
)
return result, avg_logprob, final_temperature, compression_ratio
return decode_result
def get_prompt(
self,