Return result with best log prob when all temperature fallbacks failed (#356)
* Resolve Inference Selection Bug
* Refactor for better readability
* Filter out results with compression_ratio
* Refactor to avoid variable repetition
* Fix incorrect index and perform minor refactoring
* Remove final_temperature variable
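The substance of the change: every decode attempt is now recorded as a `(result, avg_logprob, temperature, compression_ratio)` tuple (which is also what makes the old `final_temperature` variable redundant), and attempts whose compression ratio stays under the threshold are additionally collected in `below_cr_threshold_results`. A high compression ratio flags degenerate, repetitive output, since repeated text compresses unusually well. When every temperature in the fallback schedule fails the quality checks, the attempt with the highest average log probability is returned instead of simply whatever the last temperature produced. A minimal sketch of that selection rule, with an invented helper name and made-up sample values:

```python
# Sketch of the selection rule this commit introduces (illustration only,
# not the library code). Each decode attempt is a tuple of
# (result, avg_logprob, temperature, compression_ratio).

def select_best(all_results, below_cr_threshold_results):
    # Prefer attempts that passed the compression-ratio check; `or` falls
    # back to all attempts when none passed. Index 1 is avg_logprob.
    return max(below_cr_threshold_results or all_results, key=lambda x: x[1])


# Hypothetical attempts, all of which failed at least one quality check:
attempts = [
    ("hyp-a", -0.70, 0.0, 2.9),  # best logprob, but too repetitive
    ("hyp-b", -0.85, 0.2, 2.1),
    ("hyp-c", -0.95, 0.4, 2.0),
]
threshold = 2.4  # stand-in for options.compression_ratio_threshold
below_cr = [a for a in attempts if a[3] <= threshold]

print(select_best(attempts, below_cr))  # ('hyp-b', -0.85, 0.2, 2.1)
```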
```diff
@@ -578,10 +578,9 @@ class WhisperModel:
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
-        result = None
-        avg_logprob = None
-        final_temperature = None
-        compression_ratio = None
+        decode_result = None
+        all_results = []
+        below_cr_threshold_results = []

         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -601,7 +600,6 @@ class WhisperModel:
                 "patience": options.patience,
             }

-            final_temperature = temperature
             result = self.model.generate(
                 encoder_output,
                 [prompt],
@@ -625,12 +623,18 @@ class WhisperModel:
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)

+            decode_result = (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            )
+            all_results.append(decode_result)
+
             needs_fallback = False

-            if (
-                options.compression_ratio_threshold is not None
-                and compression_ratio > options.compression_ratio_threshold
-            ):
+            if options.compression_ratio_threshold is not None:
+                if compression_ratio > options.compression_ratio_threshold:
                     needs_fallback = True  # too repetitive

                     self.logger.debug(
@@ -639,6 +643,8 @@ class WhisperModel:
                         compression_ratio,
                         options.compression_ratio_threshold,
                     )
+                else:
+                    below_cr_threshold_results.append(decode_result)

             if (
                 options.log_prob_threshold is not None
@@ -661,8 +667,13 @@ class WhisperModel:

             if not needs_fallback:
                 break
+        else:
+            # all failed, select the result with the highest average log probability
+            decode_result = max(
+                below_cr_threshold_results or all_results, key=lambda x: x[1]
+            )

-        return result, avg_logprob, final_temperature, compression_ratio
+        return decode_result

     def get_prompt(
         self,
```
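A subtlety worth flagging in the last hunk: the new `else:` binds to the `for` loop, not to an `if`. Python runs a loop's `else` branch only when the loop finishes without hitting `break`, which here is exactly the case where no temperature produced an acceptable result. A tiny standalone illustration (the `attempt` helper is hypothetical):

```python
def attempt(temperature: float) -> bool:
    # Hypothetical stand-in for a decode attempt that passes its quality checks.
    return temperature >= 0.6

for temperature in (0.0, 0.2, 0.4):
    if attempt(temperature):
        break  # acceptable result; the else branch is skipped
else:
    # Reached only because every temperature failed.
    print("all temperatures failed; fall back to the best-logprob result")
```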