From 39fddba8864fb52520ca100d845b199ed7732e06 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Thu, 30 Mar 2023 12:42:29 +0200
Subject: [PATCH] Suppress some special tokens when the default set is not used

---
 faster_whisper/tokenizer.py  | 12 ++++++++++++
 faster_whisper/transcribe.py | 23 ++++++++++++++++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py
index db69cf1..efe22a3 100644
--- a/faster_whisper/tokenizer.py
+++ b/faster_whisper/tokenizer.py
@@ -33,10 +33,22 @@ class Tokenizer:
             self.language = None
             self.language_code = "en"
 
+    @cached_property
+    def transcribe(self) -> int:
+        return self.tokenizer.token_to_id("<|transcribe|>")
+
+    @cached_property
+    def translate(self) -> int:
+        return self.tokenizer.token_to_id("<|translate|>")
+
     @cached_property
     def sot(self) -> int:
         return self.tokenizer.token_to_id("<|startoftranscript|>")
 
+    @cached_property
+    def sot_lm(self) -> int:
+        return self.tokenizer.token_to_id("<|startoflm|>")
+
     @cached_property
     def sot_prev(self) -> int:
         return self.tokenizer.token_to_id("<|startofprev|>")
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 96d2e7c..d716cb3 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -242,7 +242,7 @@ class WhisperModel:
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=suppress_tokens,
+            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
             word_timestamps=word_timestamps,
@@ -703,6 +703,27 @@ def get_compression_ratio(text: str) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
+def get_suppressed_tokens(tokenizer, suppress_tokens):
+    if not suppress_tokens or -1 in suppress_tokens:
+        return suppress_tokens
+
+    suppress_tokens = list(suppress_tokens)
+
+    # Ensure the following special tokens are suppressed when the user does
+    # not use the default set (-1).
+    suppress_tokens.extend(
+        [
+            tokenizer.transcribe,
+            tokenizer.translate,
+            tokenizer.sot,
+            tokenizer.sot_prev,
+            tokenizer.sot_lm,
+        ]
+    )
+
+    return sorted(set(suppress_tokens))
+
+
 def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
     # merge prepended punctuations
     i = len(alignment) - 2
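
For illustration only (not part of the patch), here is a minimal self-contained sketch of how the new get_suppressed_tokens helper behaves. The SimpleNamespace stub stands in for the real faster_whisper Tokenizer, and the token ids are illustrative placeholders rather than guaranteed Whisper vocabulary ids:

from types import SimpleNamespace

# Hypothetical stub for faster_whisper.tokenizer.Tokenizer; the ids below are
# illustrative placeholders, not guaranteed vocabulary ids.
tokenizer = SimpleNamespace(
    sot=50258, translate=50358, transcribe=50359, sot_lm=50360, sot_prev=50361
)

def get_suppressed_tokens(tokenizer, suppress_tokens):
    # -1 (or an empty list) selects the default suppression set, so it is
    # passed through unchanged for the decoding backend to handle.
    if not suppress_tokens or -1 in suppress_tokens:
        return suppress_tokens

    # Any explicit user-provided list is extended with the special tokens,
    # then de-duplicated and sorted.
    suppress_tokens = list(suppress_tokens)
    suppress_tokens.extend(
        [
            tokenizer.transcribe,
            tokenizer.translate,
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
        ]
    )
    return sorted(set(suppress_tokens))

print(get_suppressed_tokens(tokenizer, [-1]))   # [-1]: default set, untouched
print(get_suppressed_tokens(tokenizer, [220]))  # [220, 50258, 50358, 50359, 50360, 50361]

In other words, users who opt out of the default set by passing their own token list still get the task and start-of-transcript special tokens suppressed, which they would otherwise have had to list manually.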