Create a helper class Tokenizer
This commit is contained in:
69
faster_whisper/tokenizer.py
Normal file
69
faster_whisper/tokenizer.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from functools import cached_property
|
||||
from typing import List, Optional
|
||||
|
||||
import tokenizers
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""Simple wrapper around a tokenizers.Tokenizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: tokenizers.Tokenizer,
|
||||
multilingual: bool,
|
||||
task: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if multilingual:
|
||||
self.task = self.tokenizer.token_to_id("<|%s|>" % task)
|
||||
if self.task is None:
|
||||
raise ValueError("%s is not a valid task" % task)
|
||||
|
||||
self.language = self.tokenizer.token_to_id("<|%s|>" % language)
|
||||
if self.language is None:
|
||||
raise ValueError("%s is not a valid language code" % language)
|
||||
|
||||
else:
|
||||
self.task = None
|
||||
self.language = None
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self.tokenizer.token_to_id("<|startoftranscript|>")
|
||||
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self.tokenizer.token_to_id("<|startofprev|>")
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.token_to_id("<|endoftext|>")
|
||||
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self.tokenizer.token_to_id("<|notimestamps|>")
|
||||
|
||||
@property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.no_timestamps + 1
|
||||
|
||||
@property
|
||||
def sot_sequence(self) -> List[int]:
|
||||
sequence = [self.sot]
|
||||
|
||||
if self.language is not None:
|
||||
sequence.append(self.language)
|
||||
|
||||
if self.task is not None:
|
||||
sequence.append(self.task)
|
||||
|
||||
return sequence
|
||||
|
||||
def encode(self, text: str) -> List[int]:
|
||||
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||
|
||||
def decode(self, tokens: List[int]) -> str:
|
||||
text_tokens = [token for token in tokens if token < self.eot]
|
||||
return self.tokenizer.decode(text_tokens)
|
||||
Reference in New Issue
Block a user