Expose generation parameter no_repeat_ngram_size (#449)

This commit is contained in:
Guillaume Klein
2023-09-01 17:31:30 +02:00
committed by GitHub
parent 5871858a5f
commit f0ff12965a

View File

@@ -48,6 +48,7 @@ class TranscriptionOptions(NamedTuple):
patience: float
length_penalty: float
repetition_penalty: float
no_repeat_ngram_size: int
log_prob_threshold: Optional[float]
no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float]
@@ -163,6 +164,7 @@ class WhisperModel:
patience: float = 1,
length_penalty: float = 1,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0,
0.2,
@@ -202,6 +204,7 @@ class WhisperModel:
length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
@@ -327,6 +330,7 @@ class WhisperModel:
patience=patience,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
@@ -630,6 +634,7 @@ class WhisperModel:
[prompt],
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
max_length=self.max_length,
return_scores=True,
return_no_speech_prob=True,