import string

from functools import cached_property
from typing import List, Optional, Tuple

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
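        """Wrap ``tokenizer`` and, for multilingual models, resolve ``task`` and ``language`` to their special token ids."""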
        self.tokenizer = tokenizer

        if multilingual:
            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            if self.task is None:
                raise ValueError("%s is not a valid task" % task)

            self.language_code = language
            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            if self.language is None:
                raise ValueError("%s is not a valid language code" % language)

        else:
            self.task = None
            self.language = None
            self.language_code = "en"

    @cached_property
    def transcribe(self) -> int:
        return self.tokenizer.token_to_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self.tokenizer.token_to_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self.tokenizer.token_to_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
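        """Decoder prompt token ids: <|startoftranscript|>, followed by the language and task tokens when set."""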
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence

    def encode(self, text: str) -> List[int]:
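        """Encode text into token ids, without adding special tokens."""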
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
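        """Decode token ids into text, dropping special and timestamp tokens (any id >= <|endoftext|>)."""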
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)

    def decode_with_timestamps(self, tokens: List[int]) -> str:
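        """Decode token ids into text, rendering timestamp tokens as <|x.xx|> markers (one step = 0.02 s)."""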
        outputs = [[]]

        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)

        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )

    def split_to_word_tokens(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
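        """Split tokens into words, returning both the word strings and the token ids behind each word."""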
        if self.language_code in {"zh", "ja", "th", "lo", "my"}:
            # These languages don't typically use spaces, so it is difficult to split
            # words without morpheme analysis. Here, we instead split words at any
            # position where the tokens decode to valid Unicode code points.
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
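        """Split tokens at every position where the accumulated tokens decode to valid Unicode characters."""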
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None
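            # Only cut a word here if the partial decoding is complete: either it contains
            # no replacement character, or the replacement character also appears at the
            # same position in the full decoding (i.e. it is genuinely part of the text).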
            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
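        """Merge Unicode-level splits into space-delimited words; special tokens, leading spaces, and punctuation start new words."""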
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
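

# Example usage (a sketch): assumes a Whisper-style "tokenizer.json" on disk that defines
# the <|startoftranscript|>, <|transcribe|>, <|en|>, ... special tokens; the path and the
# chosen task/language here are illustrative, not prescribed by this module.
#
#     hf_tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
#     tokenizer = Tokenizer(hf_tokenizer, multilingual=True, task="transcribe", language="en")
#     prompt = tokenizer.sot_sequence + [tokenizer.no_timestamps]
#     ids = tokenizer.encode("Hello world")
#     print(tokenizer.decode(ids))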