suppress generating non-timestamp tokens at the beginning (#532)
This commit is contained in:
@@ -423,10 +423,14 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
else: # cannot be normal text tokens
|
else: # cannot be normal text tokens
|
||||||
logits[k, : self.tokenizer.eot] = -np.inf
|
logits[k, : self.tokenizer.eot] = -np.inf
|
||||||
|
|
||||||
# apply the `max_initial_timestamp` option
|
if tokens.shape[1] == self.sample_begin:
|
||||||
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
# suppress generating non-timestamp tokens at the beginning
|
||||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
logits[:, last_allowed + 1 :] = -np.inf
|
|
||||||
|
# apply the `max_initial_timestamp` option
|
||||||
|
if self.max_initial_timestamp_index is not None:
|
||||||
|
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||||
|
logits[:, last_allowed + 1 :] = -np.inf
|
||||||
|
|
||||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user