from functools import cached_property
from typing import List, Optional

import tokenizers

class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer.

    Resolves Whisper-style special tokens ("<|transcribe|>", "<|en|>",
    transcript markers, ...) to their integer ids and exposes the prompt
    sequence used to start decoding.
    """

    def __init__(
        self,
        tokenizer: "tokenizers.Tokenizer",
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        """Wrap *tokenizer* and resolve the task/language special tokens.

        Args:
          tokenizer: the underlying tokenizers.Tokenizer instance.
          multilingual: whether the model vocabulary contains task and
            language special tokens.
          task: task name (e.g. "transcribe"); only used when
            ``multilingual`` is true. ``None`` leaves the task unset.
          language: language code (e.g. "en"); only used when
            ``multilingual`` is true. ``None`` leaves the language unset.

        Raises:
          ValueError: if a non-None ``task`` or ``language`` has no
            matching special token in the vocabulary.
        """
        self.tokenizer = tokenizer

        if multilingual:
            # task/language are optional (sot_sequence already handles the
            # None case); resolve them to special-token ids only when given.
            # An unknown name has no "<|name|>" token, so token_to_id
            # returns None and we reject it explicitly.
            if task is not None:
                self.task = self.tokenizer.token_to_id("<|%s|>" % task)
                if self.task is None:
                    raise ValueError("%s is not a valid task" % task)
            else:
                self.task = None

            if language is not None:
                self.language = self.tokenizer.token_to_id("<|%s|>" % language)
                if self.language is None:
                    raise ValueError("%s is not a valid language code" % language)
            else:
                self.language = None
        else:
            # Monolingual vocabularies have no task/language tokens.
            self.task = None
            self.language = None

    @cached_property
    def sot(self) -> int:
        """Id of the <|startoftranscript|> token."""
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_prev(self) -> int:
        """Id of the <|startofprev|> token."""
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        """Id of the <|endoftext|> token."""
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        """Id of the <|notimestamps|> token."""
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        """Id of the first timestamp token.

        Assumes timestamp tokens directly follow <|notimestamps|> in the
        vocabulary (the Whisper layout) — the code only adds 1.
        """
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
        """Decoder prompt: start-of-transcript, then the optional
        language and task ids (in that order) when they were resolved."""
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence

    def encode(self, text: str) -> List[int]:
        """Encode *text* to token ids without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        """Decode *tokens* to text, dropping every id >= eot (special and
        timestamp tokens all sit at or above <|endoftext|>)."""
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)
|