Filter out non_speech_tokens in suppressed tokens (#898)
* Filter out non_speech_tokens in suppressed tokens
This commit is contained in:
@@ -105,6 +105,42 @@ class Tokenizer:
|
|||||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def non_speech_tokens(self) -> Tuple[int]:
|
||||||
|
"""
|
||||||
|
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||||
|
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||||
|
|
||||||
|
- ♪♪♪
|
||||||
|
- ( SPEAKING FOREIGN LANGUAGE )
|
||||||
|
- [DAVID] Hey there,
|
||||||
|
|
||||||
|
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||||
|
"""
|
||||||
|
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||||
|
symbols += (
|
||||||
|
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||||
|
)
|
||||||
|
|
||||||
|
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||||
|
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||||
|
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||||
|
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||||
|
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||||
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
|
result = {self.encode(" -")[0], self.encode(" '")[0]}
|
||||||
|
for symbol in symbols + list(miscellaneous):
|
||||||
|
for tokens in [
|
||||||
|
self.encode(symbol),
|
||||||
|
self.encode(" " + symbol),
|
||||||
|
]:
|
||||||
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
|
result.add(tokens[0])
|
||||||
|
|
||||||
|
return tuple(sorted(result))
|
||||||
|
|
||||||
def split_to_word_tokens(
|
def split_to_word_tokens(
|
||||||
self, tokens: List[int]
|
self, tokens: List[int]
|
||||||
) -> Tuple[List[str], List[List[int]]]:
|
) -> Tuple[List[str], List[List[int]]]:
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ class WhisperModel:
|
|||||||
prefix: Optional text to provide as a prefix for the first window.
|
prefix: Optional text to provide as a prefix for the first window.
|
||||||
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
||||||
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
||||||
of symbols as defined in the model config.json file.
|
of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||||
without_timestamps: Only sample text tokens.
|
without_timestamps: Only sample text tokens.
|
||||||
max_initial_timestamp: The initial timestamp cannot be later than this.
|
max_initial_timestamp: The initial timestamp cannot be later than this.
|
||||||
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
||||||
@@ -462,7 +462,11 @@ class WhisperModel:
|
|||||||
initial_prompt=initial_prompt,
|
initial_prompt=initial_prompt,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
suppress_blank=suppress_blank,
|
suppress_blank=suppress_blank,
|
||||||
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
suppress_tokens=(
|
||||||
|
get_suppressed_tokens(tokenizer, suppress_tokens)
|
||||||
|
if suppress_tokens
|
||||||
|
else suppress_tokens
|
||||||
|
),
|
||||||
without_timestamps=without_timestamps,
|
without_timestamps=without_timestamps,
|
||||||
max_initial_timestamp=max_initial_timestamp,
|
max_initial_timestamp=max_initial_timestamp,
|
||||||
word_timestamps=word_timestamps,
|
word_timestamps=word_timestamps,
|
||||||
@@ -488,7 +492,6 @@ class WhisperModel:
|
|||||||
vad_options=vad_parameters,
|
vad_options=vad_parameters,
|
||||||
all_language_probs=all_language_probs,
|
all_language_probs=all_language_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return segments, info
|
return segments, info
|
||||||
|
|
||||||
def generate_segments(
|
def generate_segments(
|
||||||
@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
|
|||||||
|
|
||||||
def get_suppressed_tokens(
|
def get_suppressed_tokens(
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
suppress_tokens: Optional[List[int]],
|
suppress_tokens: Tuple[int],
|
||||||
) -> Optional[List[int]]:
|
) -> Optional[List[int]]:
|
||||||
if not suppress_tokens or -1 in suppress_tokens:
|
if -1 in suppress_tokens:
|
||||||
return suppress_tokens
|
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||||
|
suppress_tokens.extend(tokenizer.non_speech_tokens)
|
||||||
|
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||||
|
suppress_tokens = [] # interpret empty string as an empty list
|
||||||
|
else:
|
||||||
|
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||||
|
|
||||||
suppress_tokens = list(suppress_tokens)
|
|
||||||
|
|
||||||
# Ensure the following special tokens are suppressed when the user does
|
|
||||||
# not use the default set (-1).
|
|
||||||
suppress_tokens.extend(
|
suppress_tokens.extend(
|
||||||
[
|
[
|
||||||
tokenizer.transcribe,
|
tokenizer.transcribe,
|
||||||
@@ -1246,7 +1250,7 @@ def get_suppressed_tokens(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return sorted(set(suppress_tokens))
|
return tuple(sorted(set(suppress_tokens)))
|
||||||
|
|
||||||
|
|
||||||
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from faster_whisper import WhisperModel, decode_audio
|
from faster_whisper import WhisperModel, decode_audio
|
||||||
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
|
from faster_whisper.transcribe import get_suppressed_tokens
|
||||||
|
|
||||||
|
|
||||||
def test_supported_languages():
|
def test_supported_languages():
|
||||||
@@ -97,3 +99,109 @@ def test_stereo_diarization(data_dir):
|
|||||||
segments, _ = model.transcribe(right)
|
segments, _ = model.transcribe(right)
|
||||||
transcription = "".join(segment.text for segment in segments).strip()
|
transcription = "".join(segment.text for segment in segments).strip()
|
||||||
assert transcription == "The horizon seems extremely distant."
|
assert transcription == "The horizon seems extremely distant."
|
||||||
|
|
||||||
|
|
||||||
|
def test_suppressed_tokens_minus_1():
|
||||||
|
model = WhisperModel("tiny.en")
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(model.hf_tokenizer, False)
|
||||||
|
tokens = get_suppressed_tokens(tokenizer, [-1])
|
||||||
|
assert tokens == (
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
14,
|
||||||
|
25,
|
||||||
|
26,
|
||||||
|
27,
|
||||||
|
28,
|
||||||
|
29,
|
||||||
|
31,
|
||||||
|
58,
|
||||||
|
59,
|
||||||
|
60,
|
||||||
|
61,
|
||||||
|
62,
|
||||||
|
63,
|
||||||
|
90,
|
||||||
|
91,
|
||||||
|
92,
|
||||||
|
93,
|
||||||
|
357,
|
||||||
|
366,
|
||||||
|
438,
|
||||||
|
532,
|
||||||
|
685,
|
||||||
|
705,
|
||||||
|
796,
|
||||||
|
930,
|
||||||
|
1058,
|
||||||
|
1220,
|
||||||
|
1267,
|
||||||
|
1279,
|
||||||
|
1303,
|
||||||
|
1343,
|
||||||
|
1377,
|
||||||
|
1391,
|
||||||
|
1635,
|
||||||
|
1782,
|
||||||
|
1875,
|
||||||
|
2162,
|
||||||
|
2361,
|
||||||
|
2488,
|
||||||
|
3467,
|
||||||
|
4008,
|
||||||
|
4211,
|
||||||
|
4600,
|
||||||
|
4808,
|
||||||
|
5299,
|
||||||
|
5855,
|
||||||
|
6329,
|
||||||
|
7203,
|
||||||
|
9609,
|
||||||
|
9959,
|
||||||
|
10563,
|
||||||
|
10786,
|
||||||
|
11420,
|
||||||
|
11709,
|
||||||
|
11907,
|
||||||
|
13163,
|
||||||
|
13697,
|
||||||
|
13700,
|
||||||
|
14808,
|
||||||
|
15306,
|
||||||
|
16410,
|
||||||
|
16791,
|
||||||
|
17992,
|
||||||
|
19203,
|
||||||
|
19510,
|
||||||
|
20724,
|
||||||
|
22305,
|
||||||
|
22935,
|
||||||
|
27007,
|
||||||
|
30109,
|
||||||
|
30420,
|
||||||
|
33409,
|
||||||
|
34949,
|
||||||
|
40283,
|
||||||
|
40493,
|
||||||
|
40549,
|
||||||
|
47282,
|
||||||
|
49146,
|
||||||
|
50257,
|
||||||
|
50357,
|
||||||
|
50358,
|
||||||
|
50359,
|
||||||
|
50360,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_suppressed_tokens_minus_value():
|
||||||
|
model = WhisperModel("tiny.en")
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(model.hf_tokenizer, False)
|
||||||
|
tokens = get_suppressed_tokens(tokenizer, [13])
|
||||||
|
assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
|
||||||
|
|||||||
Reference in New Issue
Block a user