suppress generating non-timestamp tokens at the beginning (#532)

This commit is contained in:
jumon
2022-11-16 04:44:36 +09:00
committed by GitHub
parent 9f70a352f9
commit 76148a56c5

View File

@@ -423,8 +423,12 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
-        # apply the `max_initial_timestamp` option
-        if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
-            last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
-            logits[:, last_allowed + 1 :] = -np.inf
+        if tokens.shape[1] == self.sample_begin:
+            # suppress generating non-timestamp tokens at the beginning
+            logits[:, : self.tokenizer.timestamp_begin] = -np.inf
+            # apply the `max_initial_timestamp` option
+            if self.max_initial_timestamp_index is not None:
+                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                logits[:, last_allowed + 1 :] = -np.inf