Use tiktoken (#1044)

* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep (see the sketch below)

---------

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
Author: Jong Wook Kim
Date: 2023-03-13 05:34:16 -04:00
Committed by: GitHub
Parent: ad3250a846
Commit: 839639a223
15 changed files with 100601 additions and 100096 deletions
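The "bypassing load_tiktoken_bpe" bullet refers to parsing the .tiktoken rank file by hand instead of calling tiktoken.load.load_tiktoken_bpe, which would pull in a blobfile dependency. A minimal sketch of that approach, mirroring the get_encoding function added later in this diff (the helper name load_whisper_encoding is illustrative, not part of the commit):

    import base64
    import os
    from typing import Dict, List

    import tiktoken
    from tiktoken_ext.openai_public import gpt2


    def load_whisper_encoding(vocab_path: str, specials: List[str]) -> tiktoken.Encoding:
        # Parse "<base64 token> <rank>" lines ourselves instead of calling
        # tiktoken.load.load_tiktoken_bpe, which requires blobfile.
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in open(vocab_path) if line)
        }
        # Special tokens get ids directly after the BPE ranks.
        special_tokens: Dict[str, int] = {
            token: len(ranks) + i for i, token in enumerate(specials)
        }
        return tiktoken.Encoding(
            name=os.path.basename(vocab_path),
            explicit_n_vocab=len(ranks) + len(specials),
            pat_str=gpt2()["pat_str"],  # reuse GPT-2's pre-tokenization regex
            mergeable_ranks=ranks,
            special_tokens=special_tokens,
        )

Because the specials are appended contiguously after the ranks, explicit_n_vocab stays consistent with the rank file without any external download of the vocabulary.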

MANIFEST.in

@@ -2,6 +2,4 @@ include requirements.txt
 include README.md
 include LICENSE
 include whisper/assets/*
-include whisper/assets/gpt2/*
-include whisper/assets/multilingual/*
 include whisper/normalizers/english.json

requirements.txt

@@ -3,5 +3,5 @@ numpy
 torch
 tqdm
 more-itertools
-transformers>=4.19.0
+tiktoken==0.3.1
 ffmpeg-python==0.2.0

tests/test_transcribe.py

@@ -4,6 +4,7 @@ import pytest
 import torch
 import whisper
+from whisper.tokenizer import get_tokenizer


 @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -24,6 +25,11 @@ def test_transcribe(model_name: str):
     assert "your country" in transcription
     assert "do for you" in transcription

+    tokenizer = get_tokenizer(model.is_multilingual)
+    all_tokens = [t for s in result["segments"] for t in s["tokens"]]
+    assert tokenizer.decode(all_tokens) == result["text"]
+    assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
+
     timing_checked = False
     for segment in result["segments"]:
         for timing in segment["words"]:
@@ -31,7 +37,6 @@ def test_transcribe(model_name: str):
             if timing["word"].strip(" ,") == "Americans":
                 assert timing["start"] <= 1.8
                 assert timing["end"] >= 1.8
-                print(timing)
                 timing_checked = True

     assert timing_checked

whisper/assets/gpt2.tiktoken (new file, 50256 lines)

File diff suppressed because it is too large.
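Given how get_encoding appends the specials after the BPE ranks, the special tokens (starting with <|endoftext|> and <|startoftranscript|>) sit immediately above the rank ids, and the 1501 timestamp tokens <|0.00|> through <|30.00|> come last. A small, hedged sketch that prints this layout, assuming the whisper package with these assets is importable:

    from whisper.tokenizer import get_encoding

    enc = get_encoding("gpt2")

    # Special tokens are assigned ids right after the mergeable BPE ranks,
    # so they form a contiguous block at the top of the vocabulary.
    print(enc.n_vocab)
    print(enc.encode_single_token("<|endoftext|>"))         # first id above the BPE ranks
    print(enc.encode_single_token("<|startoftranscript|>"))
    print(enc.encode_single_token("<|0.00|>"))              # first of the 1501 timestamp tokens
    print(enc.encode_single_token("<|30.00|>"))             # last timestamp token (1500 * 0.02 s)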

whisper/assets/gpt2/merges.txt (deleted)

File diff suppressed because it is too large.

whisper/assets/gpt2/special_tokens_map.json (deleted)

@@ -1 +0,0 @@
-{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

whisper/assets/gpt2/tokenizer_config.json (deleted)

@@ -1 +0,0 @@
-{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

whisper/assets/gpt2/vocab.json (deleted)

File diff suppressed because one or more lines are too long.

whisper/assets/multilingual.tiktoken (new file)

File diff suppressed because it is too large.

whisper/assets/multilingual/added_tokens.json (deleted)

@@ -1 +0,0 @@
-{"<|endoftext|>": 50257}

whisper/assets/multilingual/merges.txt (deleted)

File diff suppressed because it is too large.

whisper/assets/multilingual/special_tokens_map.json (deleted)

@@ -1 +0,0 @@
-{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

whisper/assets/multilingual/tokenizer_config.json (deleted)

@@ -1 +0,0 @@
-{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

whisper/assets/multilingual/vocab.json (deleted)

File diff suppressed because one or more lines are too long.

whisper/tokenizer.py

@@ -1,12 +1,12 @@
+import base64
 import os
 import string
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property, lru_cache
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple

-import numpy as np
-import torch
-from transformers import GPT2TokenizerFast
+import tiktoken
+from tiktoken_ext.openai_public import gpt2

 LANGUAGES = {
     "en": "english",
@@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
 }


-@dataclass(frozen=True)
+@dataclass
 class Tokenizer:
-    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
+    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

-    tokenizer: "GPT2TokenizerFast"
-    language: Optional[str]
-    sot_sequence: Tuple[int]
+    encoding: tiktoken.Encoding
+    language: Optional[str] = None
+    task: Optional[str] = None
+    sot_sequence: Tuple[int] = ()
+    special_tokens: Dict[str, int] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for special in self.encoding.special_tokens_set:
+            special_token = self.encoding.encode_single_token(special)
+            self.special_tokens[special] = special_token
+
+        sot: int = self.special_tokens["<|startoftranscript|>"]
+        translate: int = self.special_tokens["<|translate|>"]
+        transcribe: int = self.special_tokens["<|transcribe|>"]
+
+        langs = tuple(LANGUAGES.keys())
+        sot_sequence = [sot]
+        if self.language is not None:
+            sot_sequence.append(sot + 1 + langs.index(self.language))
+        if self.task is not None:
+            task_token: int = transcribe if self.task == "transcribe" else translate
+            sot_sequence.append(task_token)
+
+        self.sot_sequence = tuple(sot_sequence)

     def encode(self, text, **kwargs):
-        return self.tokenizer.encode(text, **kwargs)
+        return self.encoding.encode(text, **kwargs)

-    def decode(
-        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
-    ):
-        return self.tokenizer.decode(token_ids, **kwargs)
+    def decode(self, token_ids: List[int], **kwargs) -> str:
+        token_ids = [t for t in token_ids if t < self.timestamp_begin]
+        return self.encoding.decode(token_ids, **kwargs)

-    def decode_with_timestamps(self, tokens) -> str:
+    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
         """
-        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
         This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
         """
-        outputs = [[]]
-        for token in tokens:
-            if token >= self.timestamp_begin:
-                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
-                outputs.append(timestamp)
-                outputs.append([])
-            else:
-                outputs[-1].append(token)
-        return "".join(
-            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
-        )
+        return self.encoding.decode(token_ids, **kwargs)

     @cached_property
     def eot(self) -> int:
-        return self.tokenizer.eos_token_id
+        return self.encoding.eot_token

     @cached_property
     def transcribe(self) -> int:
-        return self._get_single_token_id("<|transcribe|>")
+        return self.special_tokens["<|transcribe|>"]

     @cached_property
     def translate(self) -> int:
-        return self._get_single_token_id("<|translate|>")
+        return self.special_tokens["<|translate|>"]

     @cached_property
     def sot(self) -> int:
-        return self._get_single_token_id("<|startoftranscript|>")
+        return self.special_tokens["<|startoftranscript|>"]

     @cached_property
     def sot_lm(self) -> int:
-        return self._get_single_token_id("<|startoflm|>")
+        return self.special_tokens["<|startoflm|>"]

     @cached_property
     def sot_prev(self) -> int:
-        return self._get_single_token_id("<|startofprev|>")
+        return self.special_tokens["<|startofprev|>"]

     @cached_property
     def no_speech(self) -> int:
-        return self._get_single_token_id("<|nospeech|>")
+        return self.special_tokens["<|nospeech|>"]

     @cached_property
     def no_timestamps(self) -> int:
-        return self._get_single_token_id("<|notimestamps|>")
+        return self.special_tokens["<|notimestamps|>"]

     @cached_property
     def timestamp_begin(self) -> int:
-        return self.tokenizer.all_special_ids[-1] + 1
+        return self.special_tokens["<|0.00|>"]

     @cached_property
     def language_token(self) -> int:
@@ -202,25 +212,15 @@ class Tokenizer:
         if self.language is None:
             raise ValueError("This tokenizer does not have language token configured")

-        additional_tokens = dict(
-            zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids,
-            )
-        )
-        candidate = f"<|{self.language}|>"
-        if candidate in additional_tokens:
-            return additional_tokens[candidate]
+        if token := self.special_tokens.get(f"<|{self.language}|>", None):
+            return token

         raise KeyError(f"Language {self.language} not found in tokenizer.")

     @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
-        for token, token_id in zip(
-            self.tokenizer.additional_special_tokens,
-            self.tokenizer.additional_special_tokens_ids,
-        ):
+        for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
         return tuple(result)
@@ -258,22 +258,17 @@ class Tokenizer:
         assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-        result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
+        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
         for symbol in symbols + list(miscellaneous):
             for tokens in [
-                self.tokenizer.encode(symbol),
-                self.tokenizer.encode(" " + symbol),
+                self.encoding.encode(symbol),
+                self.encoding.encode(" " + symbol),
             ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])

         return tuple(sorted(result))

-    def _get_single_token_id(self, text) -> int:
-        tokens = self.tokenizer.encode(text)
-        assert len(tokens) == 1, f"{text} is not encoded as a single token"
-        return tokens[0]
-
     def split_to_word_tokens(self, tokens: List[int]):
         if self.language in {"zh", "ja", "th", "lo", "my"}:
             # These languages don't typically use spaces, so it is difficult to split words
@@ -318,12 +313,17 @@
 @lru_cache(maxsize=None)
-def build_tokenizer(name: str = "gpt2"):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(os.path.dirname(__file__), "assets", name)
-    tokenizer = GPT2TokenizerFast.from_pretrained(path)
+def get_encoding(name: str = "gpt2"):
+    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}

     specials = [
+        "<|endoftext|>",
         "<|startoftranscript|>",
         *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
         "<|translate|>",
@@ -332,18 +332,28 @@
         "<|startofprev|>",
         "<|nospeech|>",
         "<|notimestamps|>",
+        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
     ]

-    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-    return tokenizer
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=gpt2()["pat_str"],
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens,
+    )


 @lru_cache(maxsize=None)
 def get_tokenizer(
     multilingual: bool,
     *,
-    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
     language: Optional[str] = None,
+    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
@@ -354,27 +364,14 @@ def get_tokenizer(
                 raise ValueError(f"Unsupported language: {language}")

     if multilingual:
-        tokenizer_name = "multilingual"
-        task = task or "transcribe"
+        encoding_name = "multilingual"
         language = language or "en"
+        task = task or "transcribe"
     else:
-        tokenizer_name = "gpt2"
-        task = None
+        encoding_name = "gpt2"
         language = None
+        task = None

-    tokenizer = build_tokenizer(name=tokenizer_name)
-    all_special_ids: List[int] = tokenizer.all_special_ids
-    sot: int = all_special_ids[1]
-    translate: int = all_special_ids[-6]
-    transcribe: int = all_special_ids[-5]
+    encoding = get_encoding(name=encoding_name)

-    langs = tuple(LANGUAGES.keys())
-    sot_sequence = [sot]
-    if language is not None:
-        sot_sequence.append(sot + 1 + langs.index(language))
-    if task is not None:
-        sot_sequence.append(transcribe if task == "transcribe" else translate)
-
-    return Tokenizer(
-        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
-    )
+    return Tokenizer(encoding=encoding, language=language, task=task)
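
For context, a brief usage sketch of the rewritten tokenizer API as it stands after this diff; the sample text and the expected outputs in the comments are illustrative, assuming the multilingual assets load correctly:

    from whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")

    # sot_sequence is now assembled in Tokenizer.__post_init__:
    # (<|startoftranscript|>, <|en|>, <|transcribe|>) as token ids.
    print(tokenizer.sot_sequence)

    ids = tokenizer.encode(" Hello world.")

    # decode() drops timestamp tokens; decode_with_timestamps() keeps them,
    # since timestamps are now real special tokens in the tiktoken encoding.
    with_timestamps = [tokenizer.timestamp_begin, *ids, tokenizer.timestamp_begin + 50]
    print(tokenizer.decode(with_timestamps))                  #  Hello world.
    print(tokenizer.decode_with_timestamps(with_timestamps))  # <|0.00|> Hello world.<|1.00|>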