added repetition_penalty to TranscriptionOptions (#403)

Co-authored-by: Aisu Wata <aisu.wata0@gmail.com>
This commit is contained in:
Aisu Wata
2023-08-06 05:08:24 -03:00
committed by GitHub
parent 1ce16652ee
commit 1562b02345

View File

@@ -47,6 +47,7 @@ class TranscriptionOptions(NamedTuple):
best_of: int
patience: float
length_penalty: float
repetition_penalty: float
log_prob_threshold: Optional[float]
no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float]
@@ -160,6 +161,7 @@ class WhisperModel:
best_of: int = 5,
patience: float = 1,
length_penalty: float = 1,
repetition_penalty: float = 1,
temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0,
0.2,
@@ -197,6 +199,8 @@ class WhisperModel:
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
@@ -319,6 +323,7 @@ class WhisperModel:
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
@@ -620,6 +625,7 @@ class WhisperModel:
encoder_output,
[prompt],
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
max_length=self.max_length,
return_scores=True,
return_no_speech_prob=True,