Filter out non_speech_tokens in suppressed tokens (#898)

* Filter out non_speech_tokens in suppressed tokens
This commit is contained in:
Jordi Mas
2024-07-05 09:43:11 +02:00
committed by GitHub
parent c22db5125d
commit 1195359984
3 changed files with 159 additions and 11 deletions

View File

@@ -105,6 +105,42 @@ class Tokenizer:
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
) )
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encode(" -")[0], self.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encode(symbol),
self.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens( def split_to_word_tokens(
self, tokens: List[int] self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]: ) -> Tuple[List[str], List[List[int]]]:

View File

@@ -277,7 +277,7 @@ class WhisperModel:
prefix: Optional text to provide as a prefix for the first window. prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model config.json file. of symbols as defined in `tokenizer.non_speech_tokens()`
without_timestamps: Only sample text tokens. without_timestamps: Only sample text tokens.
max_initial_timestamp: The initial timestamp cannot be later than this. max_initial_timestamp: The initial timestamp cannot be later than this.
word_timestamps: Extract word-level timestamps using the cross-attention pattern word_timestamps: Extract word-level timestamps using the cross-attention pattern
@@ -462,7 +462,11 @@ class WhisperModel:
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
prefix=prefix, prefix=prefix,
suppress_blank=suppress_blank, suppress_blank=suppress_blank,
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens), suppress_tokens=(
get_suppressed_tokens(tokenizer, suppress_tokens)
if suppress_tokens
else suppress_tokens
),
without_timestamps=without_timestamps, without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp, max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps, word_timestamps=word_timestamps,
@@ -488,7 +492,6 @@ class WhisperModel:
vad_options=vad_parameters, vad_options=vad_parameters,
all_language_probs=all_language_probs, all_language_probs=all_language_probs,
) )
return segments, info return segments, info
def generate_segments( def generate_segments(
@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
def get_suppressed_tokens( def get_suppressed_tokens(
tokenizer: Tokenizer, tokenizer: Tokenizer,
suppress_tokens: Optional[List[int]], suppress_tokens: Tuple[int],
) -> Optional[List[int]]: ) -> Optional[List[int]]:
if not suppress_tokens or -1 in suppress_tokens: if -1 in suppress_tokens:
return suppress_tokens suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens = list(suppress_tokens)
# Ensure the following special tokens are suppressed when the user does
# not use the default set (-1).
suppress_tokens.extend( suppress_tokens.extend(
[ [
tokenizer.transcribe, tokenizer.transcribe,
@@ -1246,7 +1250,7 @@ def get_suppressed_tokens(
] ]
) )
return sorted(set(suppress_tokens)) return tuple(sorted(set(suppress_tokens)))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:

View File

@@ -1,6 +1,8 @@
import os import os
from faster_whisper import WhisperModel, decode_audio from faster_whisper import WhisperModel, decode_audio
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens
def test_supported_languages(): def test_supported_languages():
@@ -97,3 +99,109 @@ def test_stereo_diarization(data_dir):
segments, _ = model.transcribe(right) segments, _ = model.transcribe(right)
transcription = "".join(segment.text for segment in segments).strip() transcription = "".join(segment.text for segment in segments).strip()
assert transcription == "The horizon seems extremely distant." assert transcription == "The horizon seems extremely distant."
def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)
tokens = get_suppressed_tokens(tokenizer, [-1])
assert tokens == (
1,
2,
7,
8,
9,
10,
14,
25,
26,
27,
28,
29,
31,
58,
59,
60,
61,
62,
63,
90,
91,
92,
93,
357,
366,
438,
532,
685,
705,
796,
930,
1058,
1220,
1267,
1279,
1303,
1343,
1377,
1391,
1635,
1782,
1875,
2162,
2361,
2488,
3467,
4008,
4211,
4600,
4808,
5299,
5855,
6329,
7203,
9609,
9959,
10563,
10786,
11420,
11709,
11907,
13163,
13697,
13700,
14808,
15306,
16410,
16791,
17992,
19203,
19510,
20724,
22305,
22935,
27007,
30109,
30420,
33409,
34949,
40283,
40493,
40549,
47282,
49146,
50257,
50357,
50358,
50359,
50360,
)
def test_suppressed_tokens_minus_value():
model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)
tokens = get_suppressed_tokens(tokenizer, [13])
assert tokens == (13, 50257, 50357, 50358, 50359, 50360)