import string

from functools import cached_property
from typing import List, Optional, Tuple

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        self.tokenizer = tokenizer

        if multilingual:
            if task not in _TASKS:
                raise ValueError(
                    "'%s' is not a valid task (accepted tasks: %s)"
                    % (task, ", ".join(_TASKS))
                )

            if language not in _LANGUAGE_CODES:
                raise ValueError(
                    "'%s' is not a valid language code (accepted language codes: %s)"
                    % (language, ", ".join(_LANGUAGE_CODES))
                )

            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            self.language_code = language
        else:
            self.task = None
            self.language = None
            self.language_code = "en"
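
    # For a multilingual model, self.task and self.language hold the ids of the
    # corresponding special tokens (e.g. <|transcribe|> and <|en|>); for an
    # English-only model those tokens do not exist, so both stay None.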

    @cached_property
    def transcribe(self) -> int:
        return self.tokenizer.token_to_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self.tokenizer.token_to_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self.tokenizer.token_to_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1
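
    # Timestamp tokens occupy the ids directly after <|notimestamps|> and encode
    # time in 0.02 s steps, so id (timestamp_begin + 150) stands for 3.00 s
    # (see decode_with_timestamps below).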

    @property
    def sot_sequence(self) -> List[int]:
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence
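
    # For task="transcribe" and language="en" this sequence decodes to
    # <|startoftranscript|><|en|><|transcribe|> (the concrete ids depend on the
    # loaded vocabulary).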

    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)
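
    # decode drops every id at or above <|endoftext|>: in the Whisper vocabulary
    # the special and timestamp tokens all sit at or after eot, so only plain
    # text tokens reach the underlying tokenizer.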

    def decode_with_timestamps(self, tokens: List[int]) -> str:
        outputs = [[]]

        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)

        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )
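
    # Illustrative output: a segment spanning 0.00 s to 1.20 s decodes to
    # something like "<|0.00|> Hello world<|1.20|>".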

    def split_to_word_tokens(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to
            # split words without morpheme analysis. Here, we instead split words
            # at any position where the tokens are decoded as valid unicode points.
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)
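
    # For a space-delimited language the result keeps the leading space on each
    # word, e.g. encode(" Hello world!") typically splits into [" Hello",
    # " world", "!"] plus the token ids backing each word (exact splits depend
    # on the vocabulary).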

    def split_tokens_on_unicode(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None

            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens
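
    # The U+FFFD check separates a genuinely incomplete UTF-8 sequence (the
    # byte-level decoder emits a replacement character until a multi-byte code
    # point is complete) from a real U+FFFD in the source text, which also
    # appears at the same position in decoded_full.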

    def split_tokens_on_spaces(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


_TASKS = (
    "transcribe",
    "translate",
)

_LANGUAGE_CODES = (
    "af",
    "am",
    "ar",
    "as",
    "az",
    "ba",
    "be",
    "bg",
    "bn",
    "bo",
    "br",
    "bs",
    "ca",
    "cs",
    "cy",
    "da",
    "de",
    "el",
    "en",
    "es",
    "et",
    "eu",
    "fa",
    "fi",
    "fo",
    "fr",
    "gl",
    "gu",
    "ha",
    "haw",
    "he",
    "hi",
    "hr",
    "ht",
    "hu",
    "hy",
    "id",
    "is",
    "it",
    "ja",
    "jw",
    "ka",
    "kk",
    "km",
    "kn",
    "ko",
    "la",
    "lb",
    "ln",
    "lo",
    "lt",
    "lv",
    "mg",
    "mi",
    "mk",
    "ml",
    "mn",
    "mr",
    "ms",
    "mt",
    "my",
    "ne",
    "nl",
    "nn",
    "no",
    "oc",
    "pa",
    "pl",
    "ps",
    "pt",
    "ro",
    "ru",
    "sa",
    "sd",
    "si",
    "sk",
    "sl",
    "sn",
    "so",
    "sq",
    "sr",
    "su",
    "sv",
    "sw",
    "ta",
    "te",
    "tg",
    "th",
    "tk",
    "tl",
    "tr",
    "tt",
    "uk",
    "ur",
    "uz",
    "vi",
    "yi",
    "yo",
    "zh",
    "yue",
)
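

if __name__ == "__main__":
    # A minimal usage sketch. It assumes a HuggingFace-style "tokenizer.json"
    # for a multilingual Whisper model is available at this path; the path is
    # illustrative, not something this module provides.
    hf_tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    tokenizer = Tokenizer(
        hf_tokenizer, multilingual=True, task="transcribe", language="en"
    )

    ids = tokenizer.encode(" Hello world!")
    print(tokenizer.sot_sequence + ids + [tokenizer.eot])
    print(tokenizer.decode(ids))
    print(tokenizer.split_to_word_tokens(ids))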