Return result with best log prob when all temperature fallbacks failed (#356)
* Resolve Inference Selection Bug
* Refactor for better readability
* Filter out results with compression_ratio
* Refactor to avoid variable repetition
* Fix incorrect index and perform minor refactoring
* Remove final_temperature variable
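What the patch implements: every decoding attempt is kept as a (result, avg_logprob, temperature, compression_ratio) tuple, and when every temperature triggers a fallback, the attempt with the highest average log probability is returned, preferring attempts whose compression ratio stayed below the threshold. A minimal sketch of that selection rule, with made-up values:

# Each tuple mirrors the patch: (result, avg_logprob, temperature, compression_ratio).
all_results = [
    ("res_t0.0", -1.20, 0.0, 2.9),
    ("res_t0.2", -0.85, 0.2, 2.6),
    ("res_t0.4", -0.95, 0.4, 2.1),
]
compression_ratio_threshold = 2.4  # hypothetical threshold

# Attempts whose compression ratio stayed at or below the threshold are preferred.
below_cr_threshold_results = [
    r for r in all_results if r[3] <= compression_ratio_threshold
]

# `below_cr_threshold_results or all_results` falls back to the full list when
# no attempt passed the threshold; x[1] is the average log probability.
decode_result = max(below_cr_threshold_results or all_results, key=lambda x: x[1])
print(decode_result)  # ('res_t0.4', -0.95, 0.4, 2.1)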
@@ -578,10 +578,9 @@ class WhisperModel:
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
-        result = None
-        avg_logprob = None
-        final_temperature = None
-        compression_ratio = None
+        decode_result = None
+        all_results = []
+        below_cr_threshold_results = []
 
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -601,7 +600,6 @@ class WhisperModel:
                     "patience": options.patience,
                 }
 
-            final_temperature = temperature
             result = self.model.generate(
                 encoder_output,
                 [prompt],
@@ -625,12 +623,18 @@ class WhisperModel:
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
 
+            decode_result = (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            )
+            all_results.append(decode_result)
+
             needs_fallback = False
 
-            if (
-                options.compression_ratio_threshold is not None
-                and compression_ratio > options.compression_ratio_threshold
-            ):
+            if options.compression_ratio_threshold is not None:
+                if compression_ratio > options.compression_ratio_threshold:
                     needs_fallback = True  # too repetitive
 
                     self.logger.debug(
@@ -639,6 +643,8 @@ class WhisperModel:
                         compression_ratio,
                         options.compression_ratio_threshold,
                     )
+                else:
+                    below_cr_threshold_results.append(decode_result)
 
             if (
                 options.log_prob_threshold is not None
@@ -661,8 +667,13 @@ class WhisperModel:
 
             if not needs_fallback:
                 break
+        else:
+            # all failed, select the result with the highest average log probability
+            decode_result = max(
+                below_cr_threshold_results or all_results, key=lambda x: x[1]
+            )
 
-        return result, avg_logprob, final_temperature, compression_ratio
+        return decode_result
 
     def get_prompt(
         self,
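The else: added to the loop in the last hunk is Python's for/else construct: the else suite runs only when the loop finishes without hitting break, i.e. only when every temperature needed a fallback. A tiny standalone illustration of the construct:

# for/else: the else suite runs only if the loop was never broken out of.
for temperature in (0.0, 0.2, 0.4):
    needs_fallback = True  # pretend every attempt fails here
    if not needs_fallback:
        break  # an acceptable result would end the search early
else:
    # reached only because no break fired: all temperatures failed
    print("all fallbacks failed; pick the result with the best avg_logprob")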