From 1195359984770c708449d7eff3c523882f6b14ea Mon Sep 17 00:00:00 2001
From: Jordi Mas
Date: Fri, 5 Jul 2024 09:43:11 +0200
Subject: [PATCH] Filter out non_speech_tokens in suppressed tokens (#898)

* Filter out non_speech_tokens in suppressed tokens
---
 faster_whisper/tokenizer.py  |  36 ++++++++++++
 faster_whisper/transcribe.py |  26 +++++----
 tests/test_transcribe.py     | 108 +++++++++++++++++++++++++++++++++++
 3 files changed, 159 insertions(+), 11 deletions(-)

diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py
index c3b13b4..3bf76a5 100644
--- a/faster_whisper/tokenizer.py
+++ b/faster_whisper/tokenizer.py
@@ -105,6 +105,42 @@ class Tokenizer:
             [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         )
 
+    @cached_property
+    def non_speech_tokens(self) -> Tuple[int]:
+        """
+        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+        - ♪♪♪
+        - ( SPEAKING FOREIGN LANGUAGE )
+        - [DAVID] Hey there,
+
+        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
+        """
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
+
+        # symbols that may be a single token or multiple tokens depending on the tokenizer.
+        # In case they're multiple tokens, suppress the first token, which is safe because:
+        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {self.encode(" -")[0], self.encode(" '")[0]}
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.encode(symbol),
+                self.encode(" " + symbol),
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
     def split_to_word_tokens(
         self, tokens: List[int]
     ) -> Tuple[List[str], List[List[int]]]:
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 9d603bd..382a366 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -277,7 +277,7 @@ class WhisperModel:
           prefix: Optional text to provide as a prefix for the first window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
-            of symbols as defined in the model config.json file.
+            of symbols as defined in `tokenizer.non_speech_tokens`.
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
           word_timestamps: Extract word-level timestamps using the cross-attention pattern
@@ -462,7 +462,11 @@ class WhisperModel:
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
+            suppress_tokens=(
+                get_suppressed_tokens(tokenizer, suppress_tokens)
+                if suppress_tokens
+                else suppress_tokens
+            ),
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
             word_timestamps=word_timestamps,
@@ -488,7 +492,6 @@ class WhisperModel:
             vad_options=vad_parameters,
             all_language_probs=all_language_probs,
         )
-
         return segments, info
 
     def generate_segments(
@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
 
 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Optional[List[int]],
+    suppress_tokens: Tuple[int],
 ) -> Optional[List[int]]:
-    if not suppress_tokens or -1 in suppress_tokens:
-        return suppress_tokens
+    if -1 in suppress_tokens:
+        suppress_tokens = [t for t in suppress_tokens if t >= 0]
+        suppress_tokens.extend(tokenizer.non_speech_tokens)
+    elif suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    else:
+        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
 
-    suppress_tokens = list(suppress_tokens)
-
-    # Ensure the following special tokens are suppressed when the user does
-    # not use the default set (-1).
     suppress_tokens.extend(
         [
             tokenizer.transcribe,
             tokenizer.translate,
             tokenizer.sot,
             tokenizer.sot_prev,
             tokenizer.sot_lm,
         ]
     )
 
-    return sorted(set(suppress_tokens))
+    return tuple(sorted(set(suppress_tokens)))
 
 
 def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index d30a0fb..7fa27b1 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -1,6 +1,8 @@
 import os
 
 from faster_whisper import WhisperModel, decode_audio
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import get_suppressed_tokens
 
 
 def test_supported_languages():
@@ -97,3 +99,109 @@ def test_stereo_diarization(data_dir):
     segments, _ = model.transcribe(right)
     transcription = "".join(segment.text for segment in segments).strip()
     assert transcription == "The horizon seems extremely distant."
+
+
+def test_suppressed_tokens_minus_1():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [-1])
+    assert tokens == (
+        1,
+        2,
+        7,
+        8,
+        9,
+        10,
+        14,
+        25,
+        26,
+        27,
+        28,
+        29,
+        31,
+        58,
+        59,
+        60,
+        61,
+        62,
+        63,
+        90,
+        91,
+        92,
+        93,
+        357,
+        366,
+        438,
+        532,
+        685,
+        705,
+        796,
+        930,
+        1058,
+        1220,
+        1267,
+        1279,
+        1303,
+        1343,
+        1377,
+        1391,
+        1635,
+        1782,
+        1875,
+        2162,
+        2361,
+        2488,
+        3467,
+        4008,
+        4211,
+        4600,
+        4808,
+        5299,
+        5855,
+        6329,
+        7203,
+        9609,
+        9959,
+        10563,
+        10786,
+        11420,
+        11709,
+        11907,
+        13163,
+        13697,
+        13700,
+        14808,
+        15306,
+        16410,
+        16791,
+        17992,
+        19203,
+        19510,
+        20724,
+        22305,
+        22935,
+        27007,
+        30109,
+        30420,
+        33409,
+        34949,
+        40283,
+        40493,
+        40549,
+        47282,
+        49146,
+        50257,
+        50357,
+        50358,
+        50359,
+        50360,
+    )
+
+
+def test_suppressed_tokens_minus_value():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [13])
+    assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
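
A quick way to sanity-check the new behaviour locally (a sketch based on the tests above; it assumes a faster-whisper install that includes #898 and downloads the tiny.en model on first use):

    from faster_whisper import WhisperModel
    from faster_whisper.tokenizer import Tokenizer
    from faster_whisper.transcribe import get_suppressed_tokens

    model = WhisperModel("tiny.en")
    tokenizer = Tokenizer(model.hf_tokenizer, False)

    # -1 expands to tokenizer.non_speech_tokens plus the special tokens
    # (transcribe, translate, sot, sot_prev, sot_lm), deduplicated and sorted.
    print(get_suppressed_tokens(tokenizer, [-1]))

    # An explicit ID list is kept as given; only the special tokens are appended.
    print(get_suppressed_tokens(tokenizer, [13]))
    # (13, 50257, 50357, 50358, 50359, 50360) for tiny.en, per the test above

Note that the call site in transcribe() now wraps get_suppressed_tokens() in a truthiness check, so None or an empty list is passed through unchanged rather than being extended with the special tokens.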