Create a helper class Tokenizer

This commit is contained in:
Guillaume Klein
2023-03-09 12:53:49 +01:00
parent f0a21ea916
commit c52adaca90
2 changed files with 104 additions and 50 deletions

View File

@@ -10,6 +10,7 @@ import tokenizers
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
@@ -26,8 +27,6 @@ class TranscriptionOptions(
collections.namedtuple(
"TranscriptionOptions",
(
"language",
"task",
"beam_size",
"best_of",
"patience",
@@ -88,15 +87,13 @@ class WhisperModel:
tokenizer_file = os.path.join(model_path, "tokenizer.json")
if os.path.isfile(tokenizer_file):
self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
else:
self.tokenizer = tokenizers.Tokenizer.from_pretrained(
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
)
self.feature_extractor = FeatureExtractor()
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
self.input_stride = 2
self.time_precision = 0.02
self.max_length = 448
@@ -187,13 +184,16 @@ class WhisperModel:
language_token, language_probability = results[0][0]
language = language_token[2:-2]
else:
if self.tokenizer.token_to_id("<|%s|>" % language) is None:
raise ValueError("%s is not a valid language code" % language)
language_probability = 1
options = TranscriptionOptions(
language=language,
tokenizer = Tokenizer(
self.hf_tokenizer,
self.model.is_multilingual,
task=task,
language=language,
)
options = TranscriptionOptions(
beam_size=beam_size,
best_of=best_of,
patience=patience,
@@ -213,7 +213,7 @@ class WhisperModel:
max_initial_timestamp=max_initial_timestamp,
)
segments = self.generate_segments(features, options)
segments = self.generate_segments(features, tokenizer, options)
audio_info = AudioInfo(
language=language,
@@ -222,7 +222,7 @@ class WhisperModel:
return segments, audio_info
def generate_segments(self, features, options):
def generate_segments(self, features, tokenizer, options):
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
seek = 0
all_tokens = []
@@ -230,7 +230,7 @@ class WhisperModel:
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = self.encode_text(initial_prompt)
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
while seek < content_frames:
@@ -243,15 +243,14 @@ class WhisperModel:
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
options.language,
tokenizer,
previous_tokens,
task=options.task,
without_timestamps=options.without_timestamps,
prefix=options.prefix,
)
result, avg_log_prob, temperature = self.generate_with_fallback(
segment, prompt, options
segment, prompt, tokenizer, options
)
if options.no_speech_threshold is not None:
@@ -276,16 +275,16 @@ class WhisperModel:
single_timestamp_ending = (
len(tokens) >= 2
and tokens[-2] < self.timestamp_begin_id
and tokens[-1] >= self.timestamp_begin_id
and tokens[-2] < tokenizer.timestamp_begin
and tokens[-1] >= tokenizer.timestamp_begin
)
consecutive_timestamps = [
i
for i in range(len(tokens))
if i > 0
and tokens[i] >= self.timestamp_begin_id
and tokens[i - 1] >= self.timestamp_begin_id
and tokens[i] >= tokenizer.timestamp_begin
and tokens[i - 1] >= tokenizer.timestamp_begin
]
if len(consecutive_timestamps) > 0:
@@ -297,9 +296,11 @@ class WhisperModel:
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0] - self.timestamp_begin_id
sliced_tokens[0] - tokenizer.timestamp_begin
)
end_timestamp_position = (
sliced_tokens[-1] - tokenizer.timestamp_begin
)
end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
start_time = (
time_offset + start_timestamp_position * self.time_precision
)
@@ -318,17 +319,17 @@ class WhisperModel:
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_position = (
tokens[last_slice - 1] - self.timestamp_begin_id
tokens[last_slice - 1] - tokenizer.timestamp_begin
)
seek += last_timestamp_position * self.input_stride
else:
duration = segment_duration
timestamps = [
token for token in tokens if token >= self.timestamp_begin_id
token for token in tokens if token >= tokenizer.timestamp_begin
]
if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
duration = last_timestamp_position * self.time_precision
current_segments.append(
@@ -344,7 +345,7 @@ class WhisperModel:
tokens = segment["tokens"]
all_tokens.extend(tokens)
text = self.decode_text_tokens(tokens)
text = tokenizer.decode(tokens)
if not text.strip():
continue
@@ -354,14 +355,7 @@ class WhisperModel:
text=text,
)
def encode_text(self, text):
return self.tokenizer.encode(text, add_special_tokens=False).ids
def decode_text_tokens(self, tokens):
text_tokens = [token for token in tokens if token < self.eot_id]
return self.tokenizer.decode(text_tokens)
def generate_with_fallback(self, segment, prompt, options):
def generate_with_fallback(self, segment, prompt, tokenizer, options):
features = self.get_input(segment)
result = None
avg_log_prob = None
@@ -406,7 +400,7 @@ class WhisperModel:
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
avg_log_prob = cum_log_prob / (seq_len + 1)
text = self.decode_text_tokens(tokens).strip()
text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text)
needs_fallback = False
@@ -430,33 +424,24 @@ class WhisperModel:
def get_prompt(
self,
language,
tokenizer,
previous_tokens,
task="transcribe",
without_timestamps=False,
prefix=None,
):
prompt = []
if previous_tokens:
prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
if self.model.is_multilingual:
prompt.extend(
[
self.tokenizer.token_to_id("<|%s|>" % language),
self.tokenizer.token_to_id("<|%s|>" % task),
]
)
prompt.extend(tokenizer.sot_sequence)
if without_timestamps:
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
prompt.append(tokenizer.no_timestamps)
if prefix:
prefix_tokens = self.encode_text(" " + prefix.strip())
prefix_tokens = tokenizer.encode(" " + prefix.strip())
if len(prefix_tokens) >= self.max_length // 2:
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
prompt.extend(prefix_tokens)