added repetition_penalty to TranscriptionOptions (#403)

Co-authored-by: Aisu Wata <aisu.wata0@gmail.com>
This commit is contained in:
Aisu Wata
2023-08-06 05:08:24 -03:00
committed by GitHub
parent 1ce16652ee
commit 1562b02345

View File

@@ -47,6 +47,7 @@ class TranscriptionOptions(NamedTuple):
best_of: int best_of: int
patience: float patience: float
length_penalty: float length_penalty: float
repetition_penalty: float
log_prob_threshold: Optional[float] log_prob_threshold: Optional[float]
no_speech_threshold: Optional[float] no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float] compression_ratio_threshold: Optional[float]
@@ -160,6 +161,7 @@ class WhisperModel:
best_of: int = 5, best_of: int = 5,
patience: float = 1, patience: float = 1,
length_penalty: float = 1, length_penalty: float = 1,
repetition_penalty: float = 1,
temperature: Union[float, List[float], Tuple[float, ...]] = [ temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0, 0.0,
0.2, 0.2,
@@ -197,6 +199,8 @@ class WhisperModel:
best_of: Number of candidates when sampling with non-zero temperature. best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor. patience: Beam search patience factor.
length_penalty: Exponential length penalty constant. length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
temperature: Temperature for sampling. It can be a tuple of temperatures, temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`. `compression_ratio_threshold` or `log_prob_threshold`.
@@ -319,6 +323,7 @@ class WhisperModel:
best_of=best_of, best_of=best_of,
patience=patience, patience=patience,
length_penalty=length_penalty, length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
log_prob_threshold=log_prob_threshold, log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold, no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold, compression_ratio_threshold=compression_ratio_threshold,
@@ -620,6 +625,7 @@ class WhisperModel:
encoder_output, encoder_output,
[prompt], [prompt],
length_penalty=options.length_penalty, length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
max_length=self.max_length, max_length=self.max_length,
return_scores=True, return_scores=True,
return_no_speech_prob=True, return_no_speech_prob=True,