Suppress some special tokens when the default set is not used
@@ -33,10 +33,22 @@ class Tokenizer:
         self.language = None
         self.language_code = "en"
 
+    @cached_property
+    def transcribe(self) -> int:
+        return self.tokenizer.token_to_id("<|transcribe|>")
+
+    @cached_property
+    def translate(self) -> int:
+        return self.tokenizer.token_to_id("<|translate|>")
+
     @cached_property
     def sot(self) -> int:
         return self.tokenizer.token_to_id("<|startoftranscript|>")
 
+    @cached_property
+    def sot_lm(self) -> int:
+        return self.tokenizer.token_to_id("<|startoflm|>")
+
     @cached_property
     def sot_prev(self) -> int:
         return self.tokenizer.token_to_id("<|startofprev|>")
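The new transcribe, translate, and sot_lm properties mirror the existing sot/sot_prev lookups: each resolves a Whisper special token string to its integer id through the wrapped HuggingFace tokenizers object. A minimal sketch of the same lookup, assuming the tokenizers package that faster-whisper builds on and a hypothetical local tokenizer.json path:

    from tokenizers import Tokenizer

    # "tokenizer.json" is a hypothetical path to a Whisper tokenizer file.
    tok = Tokenizer.from_file("tokenizer.json")
    # Returns the integer id of the special token (50359 on the usual
    # multilingual Whisper vocabulary), or None if the token is unknown.
    print(tok.token_to_id("<|transcribe|>"))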
@@ -242,7 +242,7 @@ class WhisperModel:
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=suppress_tokens,
+            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
             word_timestamps=word_timestamps,
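At this call site the user-provided suppress_tokens list is no longer passed through verbatim; it is first expanded by get_suppressed_tokens (defined in the next hunk). A hedged usage sketch, assuming faster-whisper's public transcribe API and a hypothetical audio.wav input:

    from faster_whisper import WhisperModel

    model = WhisperModel("base")
    # A custom list now also gets the special tokens above merged in;
    # the default [-1] keeps the model's built-in suppress list as before.
    segments, info = model.transcribe("audio.wav", suppress_tokens=[123])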
@@ -703,6 +703,27 @@ def get_compression_ratio(text: str) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
+def get_suppressed_tokens(tokenizer, suppress_tokens):
+    if not suppress_tokens or -1 in suppress_tokens:
+        return suppress_tokens
+
+    suppress_tokens = list(suppress_tokens)
+
+    # Ensure the following special tokens are suppressed when the user does
+    # not use the default set (-1).
+    suppress_tokens.extend(
+        [
+            tokenizer.transcribe,
+            tokenizer.translate,
+            tokenizer.sot,
+            tokenizer.sot_prev,
+            tokenizer.sot_lm,
+        ]
+    )
+
+    return sorted(set(suppress_tokens))
+
+
 def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
     # merge prepended punctuations
     i = len(alignment) - 2
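A small sketch of the helper's behavior using a stand-in tokenizer object; the attribute values below are the usual multilingual Whisper special token ids, shown purely for illustration:

    from types import SimpleNamespace

    fake_tokenizer = SimpleNamespace(
        sot=50258, translate=50358, transcribe=50359, sot_lm=50360, sot_prev=50361
    )
    # Default set (-1) or an empty list is returned unchanged.
    print(get_suppressed_tokens(fake_tokenizer, [-1]))   # [-1]
    # A custom list gets the special ids merged in, deduplicated and sorted.
    print(get_suppressed_tokens(fake_tokenizer, [123]))  # [123, 50258, 50358, 50359, 50360, 50361]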