Suppress some special tokens when the default set is not used

This commit is contained in:
Guillaume Klein
2023-03-30 12:42:29 +02:00
parent eda840f8ff
commit 39fddba886
2 changed files with 34 additions and 1 deletions

View File

@@ -33,10 +33,22 @@ class Tokenizer:
self.language = None
self.language_code = "en"
@cached_property
def transcribe(self) -> int:
return self.tokenizer.token_to_id("<|transcribe|>")
@cached_property
def translate(self) -> int:
return self.tokenizer.token_to_id("<|translate|>")
@cached_property
def sot(self) -> int:
return self.tokenizer.token_to_id("<|startoftranscript|>")
@cached_property
def sot_lm(self) -> int:
return self.tokenizer.token_to_id("<|startoflm|>")
@cached_property
def sot_prev(self) -> int:
return self.tokenizer.token_to_id("<|startofprev|>")

View File

@@ -242,7 +242,7 @@ class WhisperModel:
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=suppress_tokens,
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps,
@@ -703,6 +703,27 @@ def get_compression_ratio(text: str) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def get_suppressed_tokens(tokenizer, suppress_tokens):
if not suppress_tokens or -1 in suppress_tokens:
return suppress_tokens
suppress_tokens = list(suppress_tokens)
# Ensure the following special tokens are suppressed when the user does
# not use the default set (-1).
suppress_tokens.extend(
[
tokenizer.transcribe,
tokenizer.translate,
tokenizer.sot,
tokenizer.sot_prev,
tokenizer.sot_lm,
]
)
return sorted(set(suppress_tokens))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2