Create a helper class Tokenizer

Guillaume Klein
2023-03-09 12:53:49 +01:00
parent f0a21ea916
commit c52adaca90
2 changed files with 104 additions and 50 deletions

faster_whisper/tokenizer.py

@@ -0,0 +1,69 @@
from functools import cached_property
from typing import List, Optional

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        self.tokenizer = tokenizer

        if multilingual:
            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            if self.task is None:
                raise ValueError("%s is not a valid task" % task)

            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            if self.language is None:
                raise ValueError("%s is not a valid language code" % language)
        else:
            self.task = None
            self.language = None

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence

    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)
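A quick usage sketch (not part of the commit) shows how the wrapper is meant to be driven. It loads the same Hugging Face tokenizer the model falls back to in the diff below; the example text and language/task values are illustrative.

import tokenizers

from faster_whisper.tokenizer import Tokenizer

hf_tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
tokenizer = Tokenizer(
    hf_tokenizer, multilingual=True, task="transcribe", language="en"
)

# ids of <|startoftranscript|>, <|en|>, <|transcribe|>
print(tokenizer.sot_sequence)

ids = tokenizer.encode(" Hello world")  # no special tokens added
print(tokenizer.decode(ids))            # drops <|endoftext|> and all ids above it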

faster_whisper/transcribe.py

@@ -10,6 +10,7 @@ import tokenizers
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
+from faster_whisper.tokenizer import Tokenizer


 class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
@@ -26,8 +27,6 @@ class TranscriptionOptions(
     collections.namedtuple(
         "TranscriptionOptions",
         (
-            "language",
-            "task",
             "beam_size",
             "best_of",
             "patience",
@@ -88,15 +87,13 @@ class WhisperModel:
         tokenizer_file = os.path.join(model_path, "tokenizer.json")
         if os.path.isfile(tokenizer_file):
-            self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
+            self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
         else:
-            self.tokenizer = tokenizers.Tokenizer.from_pretrained(
+            self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
                 "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
             )

         self.feature_extractor = FeatureExtractor()
-        self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
-        self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
         self.input_stride = 2
         self.time_precision = 0.02
         self.max_length = 448
@@ -187,13 +184,16 @@ class WhisperModel:
             language_token, language_probability = results[0][0]
             language = language_token[2:-2]
         else:
-            if self.tokenizer.token_to_id("<|%s|>" % language) is None:
-                raise ValueError("%s is not a valid language code" % language)
             language_probability = 1

-        options = TranscriptionOptions(
-            language=language,
-            task=task,
+        tokenizer = Tokenizer(
+            self.hf_tokenizer,
+            self.model.is_multilingual,
+            task=task,
+            language=language,
+        )
+
+        options = TranscriptionOptions(
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
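With this hunk, an invalid language code fails when the wrapper is constructed rather than inside transcribe(). A minimal sketch, reusing the hypothetical hf_tokenizer from the earlier example and assuming "<|xx|>" is not a token in the vocabulary:

try:
    Tokenizer(hf_tokenizer, True, task="transcribe", language="xx")
except ValueError as error:
    print(error)  # "xx is not a valid language code"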
@@ -213,7 +213,7 @@ class WhisperModel:
             max_initial_timestamp=max_initial_timestamp,
         )

-        segments = self.generate_segments(features, options)
+        segments = self.generate_segments(features, tokenizer, options)

         audio_info = AudioInfo(
             language=language,
@@ -222,7 +222,7 @@ class WhisperModel:
         return segments, audio_info

-    def generate_segments(self, features, options):
+    def generate_segments(self, features, tokenizer, options):
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         seek = 0
         all_tokens = []
@@ -230,7 +230,7 @@ class WhisperModel:
         if options.initial_prompt is not None:
             initial_prompt = " " + options.initial_prompt.strip()
-            initial_prompt_tokens = self.encode_text(initial_prompt)
+            initial_prompt_tokens = tokenizer.encode(initial_prompt)
             all_tokens.extend(initial_prompt_tokens)

         while seek < content_frames:
@@ -243,15 +243,14 @@ class WhisperModel:
             previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
-                options.language,
+                tokenizer,
                 previous_tokens,
-                task=options.task,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix,
             )

             result, avg_log_prob, temperature = self.generate_with_fallback(
-                segment, prompt, options
+                segment, prompt, tokenizer, options
             )

             if options.no_speech_threshold is not None:
@@ -276,16 +275,16 @@ class WhisperModel:
             single_timestamp_ending = (
                 len(tokens) >= 2
-                and tokens[-2] < self.timestamp_begin_id
-                and tokens[-1] >= self.timestamp_begin_id
+                and tokens[-2] < tokenizer.timestamp_begin
+                and tokens[-1] >= tokenizer.timestamp_begin
             )

             consecutive_timestamps = [
                 i
                 for i in range(len(tokens))
                 if i > 0
                 and tokens[i] >= tokenizer.timestamp_begin
                 and tokens[i - 1] >= tokenizer.timestamp_begin
             ]

             if len(consecutive_timestamps) > 0:
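For intuition, a small standalone sketch (illustrative token values, not from the commit) of what consecutive_timestamps detects: Whisper emits two adjacent timestamp tokens at each internal segment boundary, and the index of the second one is where the window is sliced.

T = tokenizer.timestamp_begin
# text token ids are below T; T + n stands for the timestamp n * 0.02 s
tokens = [T + 0, 500, 501, T + 150, T + 150, 502, T + 400]
consecutive = [
    i
    for i in range(len(tokens))
    if i > 0 and tokens[i] >= T and tokens[i - 1] >= T
]
print(consecutive)  # [4]: the window splits between the two <|3.00|> tokens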
@@ -297,9 +296,11 @@ class WhisperModel:
                 for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_position = (
-                        sliced_tokens[0] - self.timestamp_begin_id
+                        sliced_tokens[0] - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_position = (
+                        sliced_tokens[-1] - tokenizer.timestamp_begin
                     )
-                    end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
                     start_time = (
                         time_offset + start_timestamp_position * self.time_precision
                     )
@@ -318,17 +319,17 @@ class WhisperModel:
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_position = (
-                        tokens[last_slice - 1] - self.timestamp_begin_id
+                        tokens[last_slice - 1] - tokenizer.timestamp_begin
                     )
                     seek += last_timestamp_position * self.input_stride
             else:
                 duration = segment_duration
                 timestamps = [
-                    token for token in tokens if token >= self.timestamp_begin_id
+                    token for token in tokens if token >= tokenizer.timestamp_begin
                 ]
-                if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
-                    last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
+                if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
+                    last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
                     duration = last_timestamp_position * self.time_precision

                 current_segments.append(
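The arithmetic in these hunks converts a timestamp token into a time and a frame offset. A worked example with the constants set in this file (time_precision = 0.02, input_stride = 2) and an illustrative token value:

time_precision = 0.02
input_stride = 2

token = tokenizer.timestamp_begin + 150       # the <|3.00|> timestamp token
position = token - tokenizer.timestamp_begin  # 150
print(position * time_precision)  # 3.0 seconds within the current window
print(position * input_stride)    # 300: frames by which `seek` advances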
@@ -344,7 +345,7 @@ class WhisperModel:
                 tokens = segment["tokens"]
                 all_tokens.extend(tokens)

-                text = self.decode_text_tokens(tokens)
+                text = tokenizer.decode(tokens)
                 if not text.strip():
                     continue
@@ -354,14 +355,7 @@ class WhisperModel:
                     text=text,
                 )

-    def encode_text(self, text):
-        return self.tokenizer.encode(text, add_special_tokens=False).ids
-
-    def decode_text_tokens(self, tokens):
-        text_tokens = [token for token in tokens if token < self.eot_id]
-        return self.tokenizer.decode(text_tokens)
-
-    def generate_with_fallback(self, segment, prompt, options):
+    def generate_with_fallback(self, segment, prompt, tokenizer, options):
         features = self.get_input(segment)
         result = None
         avg_log_prob = None
@@ -406,7 +400,7 @@ class WhisperModel:
             cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
             avg_log_prob = cum_log_prob / (seq_len + 1)

-            text = self.decode_text_tokens(tokens).strip()
+            text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)

             needs_fallback = False
@@ -430,33 +424,24 @@ class WhisperModel:
     def get_prompt(
         self,
-        language,
+        tokenizer,
         previous_tokens,
-        task="transcribe",
         without_timestamps=False,
         prefix=None,
     ):
         prompt = []

         if previous_tokens:
-            prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
+            prompt.append(tokenizer.sot_prev)
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

-        prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
-
-        if self.model.is_multilingual:
-            prompt.extend(
-                [
-                    self.tokenizer.token_to_id("<|%s|>" % language),
-                    self.tokenizer.token_to_id("<|%s|>" % task),
-                ]
-            )
+        prompt.extend(tokenizer.sot_sequence)

         if without_timestamps:
-            prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
+            prompt.append(tokenizer.no_timestamps)

         if prefix:
-            prefix_tokens = self.encode_text(" " + prefix.strip())
+            prefix_tokens = tokenizer.encode(" " + prefix.strip())
             if len(prefix_tokens) >= self.max_length // 2:
                 prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
             prompt.extend(prefix_tokens)
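The net effect of the refactor is easiest to see as a sketch of the prompt layout get_prompt() now builds (token names are the Whisper specials the wrapper caches; 223 is max_length // 2 - 1 with max_length = 448, per this file):

prompt = [
    tokenizer.sot_prev,       # <|startofprev|>, only when there is history
    *previous_tokens[-223:],  # last 223 tokens of accumulated context
    *tokenizer.sot_sequence,  # <|startoftranscript|> [<|en|>] [<|transcribe|>]
    tokenizer.no_timestamps,  # only when without_timestamps=True
    *prefix_tokens,           # encoded " " + prefix, truncated to 223 tokens
]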