diff --git a/whisper/decoding.py b/whisper/decoding.py index 81cd845..2592ba9 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -471,6 +471,13 @@ class ApplyTimestampRules(LogitFilter): # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf + # to force that timestamps are strictly increasing + if last_was_timestamp and not penultimate_was_timestamp: + timestamp_last = timestamps[-1] + else: + timestamp_last = timestamps[-1] + 1 + logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf + if tokens.shape[1] == self.sample_begin: # suppress generating non-timestamp tokens at the beginning logits[:, : self.tokenizer.timestamp_begin] = -np.inf