Create a helper class Tokenizer
This commit is contained in:
@@ -10,6 +10,7 @@ import tokenizers
|
||||
|
||||
from faster_whisper.audio import decode_audio
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
|
||||
@@ -26,8 +27,6 @@ class TranscriptionOptions(
|
||||
collections.namedtuple(
|
||||
"TranscriptionOptions",
|
||||
(
|
||||
"language",
|
||||
"task",
|
||||
"beam_size",
|
||||
"best_of",
|
||||
"patience",
|
||||
@@ -88,15 +87,13 @@ class WhisperModel:
|
||||
|
||||
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
||||
if os.path.isfile(tokenizer_file):
|
||||
self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||
else:
|
||||
self.tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
||||
)
|
||||
|
||||
self.feature_extractor = FeatureExtractor()
|
||||
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
|
||||
self.input_stride = 2
|
||||
self.time_precision = 0.02
|
||||
self.max_length = 448
|
||||
@@ -187,13 +184,16 @@ class WhisperModel:
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
else:
|
||||
if self.tokenizer.token_to_id("<|%s|>" % language) is None:
|
||||
raise ValueError("%s is not a valid language code" % language)
|
||||
language_probability = 1
|
||||
|
||||
options = TranscriptionOptions(
|
||||
language=language,
|
||||
tokenizer = Tokenizer(
|
||||
self.hf_tokenizer,
|
||||
self.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
|
||||
options = TranscriptionOptions(
|
||||
beam_size=beam_size,
|
||||
best_of=best_of,
|
||||
patience=patience,
|
||||
@@ -213,7 +213,7 @@ class WhisperModel:
|
||||
max_initial_timestamp=max_initial_timestamp,
|
||||
)
|
||||
|
||||
segments = self.generate_segments(features, options)
|
||||
segments = self.generate_segments(features, tokenizer, options)
|
||||
|
||||
audio_info = AudioInfo(
|
||||
language=language,
|
||||
@@ -222,7 +222,7 @@ class WhisperModel:
|
||||
|
||||
return segments, audio_info
|
||||
|
||||
def generate_segments(self, features, options):
|
||||
def generate_segments(self, features, tokenizer, options):
|
||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||
seek = 0
|
||||
all_tokens = []
|
||||
@@ -230,7 +230,7 @@ class WhisperModel:
|
||||
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = self.encode_text(initial_prompt)
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
|
||||
while seek < content_frames:
|
||||
@@ -243,15 +243,14 @@ class WhisperModel:
|
||||
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
options.language,
|
||||
tokenizer,
|
||||
previous_tokens,
|
||||
task=options.task,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
||||
segment, prompt, options
|
||||
segment, prompt, tokenizer, options
|
||||
)
|
||||
|
||||
if options.no_speech_threshold is not None:
|
||||
@@ -276,16 +275,16 @@ class WhisperModel:
|
||||
|
||||
single_timestamp_ending = (
|
||||
len(tokens) >= 2
|
||||
and tokens[-2] < self.timestamp_begin_id
|
||||
and tokens[-1] >= self.timestamp_begin_id
|
||||
and tokens[-2] < tokenizer.timestamp_begin
|
||||
and tokens[-1] >= tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
consecutive_timestamps = [
|
||||
i
|
||||
for i in range(len(tokens))
|
||||
if i > 0
|
||||
and tokens[i] >= self.timestamp_begin_id
|
||||
and tokens[i - 1] >= self.timestamp_begin_id
|
||||
and tokens[i] >= tokenizer.timestamp_begin
|
||||
and tokens[i - 1] >= tokenizer.timestamp_begin
|
||||
]
|
||||
|
||||
if len(consecutive_timestamps) > 0:
|
||||
@@ -297,9 +296,11 @@ class WhisperModel:
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_position = (
|
||||
sliced_tokens[0] - self.timestamp_begin_id
|
||||
sliced_tokens[0] - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = (
|
||||
sliced_tokens[-1] - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
|
||||
start_time = (
|
||||
time_offset + start_timestamp_position * self.time_precision
|
||||
)
|
||||
@@ -318,17 +319,17 @@ class WhisperModel:
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_position = (
|
||||
tokens[last_slice - 1] - self.timestamp_begin_id
|
||||
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_position * self.input_stride
|
||||
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = [
|
||||
token for token in tokens if token >= self.timestamp_begin_id
|
||||
token for token in tokens if token >= tokenizer.timestamp_begin
|
||||
]
|
||||
if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
|
||||
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
|
||||
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
||||
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
||||
duration = last_timestamp_position * self.time_precision
|
||||
|
||||
current_segments.append(
|
||||
@@ -344,7 +345,7 @@ class WhisperModel:
|
||||
tokens = segment["tokens"]
|
||||
all_tokens.extend(tokens)
|
||||
|
||||
text = self.decode_text_tokens(tokens)
|
||||
text = tokenizer.decode(tokens)
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
@@ -354,14 +355,7 @@ class WhisperModel:
|
||||
text=text,
|
||||
)
|
||||
|
||||
def encode_text(self, text):
|
||||
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||
|
||||
def decode_text_tokens(self, tokens):
|
||||
text_tokens = [token for token in tokens if token < self.eot_id]
|
||||
return self.tokenizer.decode(text_tokens)
|
||||
|
||||
def generate_with_fallback(self, segment, prompt, options):
|
||||
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
||||
features = self.get_input(segment)
|
||||
result = None
|
||||
avg_log_prob = None
|
||||
@@ -406,7 +400,7 @@ class WhisperModel:
|
||||
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
|
||||
avg_log_prob = cum_log_prob / (seq_len + 1)
|
||||
|
||||
text = self.decode_text_tokens(tokens).strip()
|
||||
text = tokenizer.decode(tokens).strip()
|
||||
compression_ratio = get_compression_ratio(text)
|
||||
|
||||
needs_fallback = False
|
||||
@@ -430,33 +424,24 @@ class WhisperModel:
|
||||
|
||||
def get_prompt(
|
||||
self,
|
||||
language,
|
||||
tokenizer,
|
||||
previous_tokens,
|
||||
task="transcribe",
|
||||
without_timestamps=False,
|
||||
prefix=None,
|
||||
):
|
||||
prompt = []
|
||||
|
||||
if previous_tokens:
|
||||
prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
|
||||
prompt.append(tokenizer.sot_prev)
|
||||
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
||||
|
||||
prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
|
||||
|
||||
if self.model.is_multilingual:
|
||||
prompt.extend(
|
||||
[
|
||||
self.tokenizer.token_to_id("<|%s|>" % language),
|
||||
self.tokenizer.token_to_id("<|%s|>" % task),
|
||||
]
|
||||
)
|
||||
prompt.extend(tokenizer.sot_sequence)
|
||||
|
||||
if without_timestamps:
|
||||
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
||||
prompt.append(tokenizer.no_timestamps)
|
||||
|
||||
if prefix:
|
||||
prefix_tokens = self.encode_text(" " + prefix.strip())
|
||||
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
||||
if len(prefix_tokens) >= self.max_length // 2:
|
||||
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
||||
prompt.extend(prefix_tokens)
|
||||
|
||||
Reference in New Issue
Block a user