Filter out non_speech_tokens in suppressed tokens (#898)
* Filter out non_speech_tokens in suppressed tokens
This commit is contained in:
@@ -277,7 +277,7 @@ class WhisperModel:
|
||||
prefix: Optional text to provide as a prefix for the first window.
|
||||
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
||||
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
||||
of symbols as defined in the model config.json file.
|
||||
of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
without_timestamps: Only sample text tokens.
|
||||
max_initial_timestamp: The initial timestamp cannot be later than this.
|
||||
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
||||
@@ -462,7 +462,11 @@ class WhisperModel:
|
||||
initial_prompt=initial_prompt,
|
||||
prefix=prefix,
|
||||
suppress_blank=suppress_blank,
|
||||
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
||||
suppress_tokens=(
|
||||
get_suppressed_tokens(tokenizer, suppress_tokens)
|
||||
if suppress_tokens
|
||||
else suppress_tokens
|
||||
),
|
||||
without_timestamps=without_timestamps,
|
||||
max_initial_timestamp=max_initial_timestamp,
|
||||
word_timestamps=word_timestamps,
|
||||
@@ -488,7 +492,6 @@ class WhisperModel:
|
||||
vad_options=vad_parameters,
|
||||
all_language_probs=all_language_probs,
|
||||
)
|
||||
|
||||
return segments, info
|
||||
|
||||
def generate_segments(
|
||||
@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
|
||||
|
||||
def get_suppressed_tokens(
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Optional[List[int]],
|
||||
suppress_tokens: Tuple[int],
|
||||
) -> Optional[List[int]]:
|
||||
if not suppress_tokens or -1 in suppress_tokens:
|
||||
return suppress_tokens
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens = list(suppress_tokens)
|
||||
|
||||
# Ensure the following special tokens are suppressed when the user does
|
||||
# not use the default set (-1).
|
||||
suppress_tokens.extend(
|
||||
[
|
||||
tokenizer.transcribe,
|
||||
@@ -1246,7 +1250,7 @@ def get_suppressed_tokens(
|
||||
]
|
||||
)
|
||||
|
||||
return sorted(set(suppress_tokens))
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user