Create a helper class Tokenizer
This commit is contained in:
69
faster_whisper/tokenizer.py
Normal file
69
faster_whisper/tokenizer.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import tokenizers
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
    """Thin convenience wrapper around a ``tokenizers.Tokenizer``.

    Resolves the Whisper special tokens (task, language, transcript
    markers) to their integer ids and exposes plain encode/decode helpers
    that never add special tokens themselves.
    """

    def __init__(
        self,
        tokenizer: "tokenizers.Tokenizer",
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        # Underlying Hugging Face tokenizer used for every id lookup.
        self.tokenizer = tokenizer

        if not multilingual:
            # English-only models carry no task/language special tokens.
            self.task = None
            self.language = None
            return

        # Map "<|task|>" / "<|language|>" to their ids; a miss means the
        # caller supplied an unknown task or language code.
        self.task = self.tokenizer.token_to_id("<|%s|>" % task)
        if self.task is None:
            raise ValueError("%s is not a valid task" % task)

        self.language = self.tokenizer.token_to_id("<|%s|>" % language)
        if self.language is None:
            raise ValueError("%s is not a valid language code" % language)

    @cached_property
    def sot(self) -> int:
        """Id of the ``<|startoftranscript|>`` token."""
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_prev(self) -> int:
        """Id of the ``<|startofprev|>`` token (previous-text prompt)."""
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        """Id of the ``<|endoftext|>`` token."""
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        """Id of the ``<|notimestamps|>`` token."""
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        # Timestamp tokens occupy the id range right after <|notimestamps|>.
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
        """Start-of-transcript prompt: sot, then language/task if present."""
        sequence = [self.sot]
        for token_id in (self.language, self.task):
            if token_id is not None:
                sequence.append(token_id)
        return sequence

    def encode(self, text: str) -> List[int]:
        """Tokenize ``text`` into ids without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        """Decode ``tokens``, dropping every special/timestamp id (>= eot)."""
        text_only = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_only)
|
||||||
@@ -10,6 +10,7 @@ import tokenizers
|
|||||||
|
|
||||||
from faster_whisper.audio import decode_audio
|
from faster_whisper.audio import decode_audio
|
||||||
from faster_whisper.feature_extractor import FeatureExtractor
|
from faster_whisper.feature_extractor import FeatureExtractor
|
||||||
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
|
class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
|
||||||
@@ -26,8 +27,6 @@ class TranscriptionOptions(
|
|||||||
collections.namedtuple(
|
collections.namedtuple(
|
||||||
"TranscriptionOptions",
|
"TranscriptionOptions",
|
||||||
(
|
(
|
||||||
"language",
|
|
||||||
"task",
|
|
||||||
"beam_size",
|
"beam_size",
|
||||||
"best_of",
|
"best_of",
|
||||||
"patience",
|
"patience",
|
||||||
@@ -88,15 +87,13 @@ class WhisperModel:
|
|||||||
|
|
||||||
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
||||||
if os.path.isfile(tokenizer_file):
|
if os.path.isfile(tokenizer_file):
|
||||||
self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||||
else:
|
else:
|
||||||
self.tokenizer = tokenizers.Tokenizer.from_pretrained(
|
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||||
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
||||||
)
|
)
|
||||||
|
|
||||||
self.feature_extractor = FeatureExtractor()
|
self.feature_extractor = FeatureExtractor()
|
||||||
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
|
|
||||||
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
|
|
||||||
self.input_stride = 2
|
self.input_stride = 2
|
||||||
self.time_precision = 0.02
|
self.time_precision = 0.02
|
||||||
self.max_length = 448
|
self.max_length = 448
|
||||||
@@ -187,13 +184,16 @@ class WhisperModel:
|
|||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
else:
|
else:
|
||||||
if self.tokenizer.token_to_id("<|%s|>" % language) is None:
|
|
||||||
raise ValueError("%s is not a valid language code" % language)
|
|
||||||
language_probability = 1
|
language_probability = 1
|
||||||
|
|
||||||
options = TranscriptionOptions(
|
tokenizer = Tokenizer(
|
||||||
language=language,
|
self.hf_tokenizer,
|
||||||
|
self.model.is_multilingual,
|
||||||
task=task,
|
task=task,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
options = TranscriptionOptions(
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
patience=patience,
|
patience=patience,
|
||||||
@@ -213,7 +213,7 @@ class WhisperModel:
|
|||||||
max_initial_timestamp=max_initial_timestamp,
|
max_initial_timestamp=max_initial_timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
segments = self.generate_segments(features, options)
|
segments = self.generate_segments(features, tokenizer, options)
|
||||||
|
|
||||||
audio_info = AudioInfo(
|
audio_info = AudioInfo(
|
||||||
language=language,
|
language=language,
|
||||||
@@ -222,7 +222,7 @@ class WhisperModel:
|
|||||||
|
|
||||||
return segments, audio_info
|
return segments, audio_info
|
||||||
|
|
||||||
def generate_segments(self, features, options):
|
def generate_segments(self, features, tokenizer, options):
|
||||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||||
seek = 0
|
seek = 0
|
||||||
all_tokens = []
|
all_tokens = []
|
||||||
@@ -230,7 +230,7 @@ class WhisperModel:
|
|||||||
|
|
||||||
if options.initial_prompt is not None:
|
if options.initial_prompt is not None:
|
||||||
initial_prompt = " " + options.initial_prompt.strip()
|
initial_prompt = " " + options.initial_prompt.strip()
|
||||||
initial_prompt_tokens = self.encode_text(initial_prompt)
|
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
|
||||||
while seek < content_frames:
|
while seek < content_frames:
|
||||||
@@ -243,15 +243,14 @@ class WhisperModel:
|
|||||||
|
|
||||||
previous_tokens = all_tokens[prompt_reset_since:]
|
previous_tokens = all_tokens[prompt_reset_since:]
|
||||||
prompt = self.get_prompt(
|
prompt = self.get_prompt(
|
||||||
options.language,
|
tokenizer,
|
||||||
previous_tokens,
|
previous_tokens,
|
||||||
task=options.task,
|
|
||||||
without_timestamps=options.without_timestamps,
|
without_timestamps=options.without_timestamps,
|
||||||
prefix=options.prefix,
|
prefix=options.prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
result, avg_log_prob, temperature = self.generate_with_fallback(
|
||||||
segment, prompt, options
|
segment, prompt, tokenizer, options
|
||||||
)
|
)
|
||||||
|
|
||||||
if options.no_speech_threshold is not None:
|
if options.no_speech_threshold is not None:
|
||||||
@@ -276,16 +275,16 @@ class WhisperModel:
|
|||||||
|
|
||||||
single_timestamp_ending = (
|
single_timestamp_ending = (
|
||||||
len(tokens) >= 2
|
len(tokens) >= 2
|
||||||
and tokens[-2] < self.timestamp_begin_id
|
and tokens[-2] < tokenizer.timestamp_begin
|
||||||
and tokens[-1] >= self.timestamp_begin_id
|
and tokens[-1] >= tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
|
|
||||||
consecutive_timestamps = [
|
consecutive_timestamps = [
|
||||||
i
|
i
|
||||||
for i in range(len(tokens))
|
for i in range(len(tokens))
|
||||||
if i > 0
|
if i > 0
|
||||||
and tokens[i] >= self.timestamp_begin_id
|
and tokens[i] >= tokenizer.timestamp_begin
|
||||||
and tokens[i - 1] >= self.timestamp_begin_id
|
and tokens[i - 1] >= tokenizer.timestamp_begin
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(consecutive_timestamps) > 0:
|
if len(consecutive_timestamps) > 0:
|
||||||
@@ -297,9 +296,11 @@ class WhisperModel:
|
|||||||
for current_slice in slices:
|
for current_slice in slices:
|
||||||
sliced_tokens = tokens[last_slice:current_slice]
|
sliced_tokens = tokens[last_slice:current_slice]
|
||||||
start_timestamp_position = (
|
start_timestamp_position = (
|
||||||
sliced_tokens[0] - self.timestamp_begin_id
|
sliced_tokens[0] - tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
end_timestamp_position = (
|
||||||
|
sliced_tokens[-1] - tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id
|
|
||||||
start_time = (
|
start_time = (
|
||||||
time_offset + start_timestamp_position * self.time_precision
|
time_offset + start_timestamp_position * self.time_precision
|
||||||
)
|
)
|
||||||
@@ -318,17 +319,17 @@ class WhisperModel:
|
|||||||
else:
|
else:
|
||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||||
last_timestamp_position = (
|
last_timestamp_position = (
|
||||||
tokens[last_slice - 1] - self.timestamp_begin_id
|
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
seek += last_timestamp_position * self.input_stride
|
seek += last_timestamp_position * self.input_stride
|
||||||
|
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
timestamps = [
|
timestamps = [
|
||||||
token for token in tokens if token >= self.timestamp_begin_id
|
token for token in tokens if token >= tokenizer.timestamp_begin
|
||||||
]
|
]
|
||||||
if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id:
|
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
||||||
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
|
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
||||||
duration = last_timestamp_position * self.time_precision
|
duration = last_timestamp_position * self.time_precision
|
||||||
|
|
||||||
current_segments.append(
|
current_segments.append(
|
||||||
@@ -344,7 +345,7 @@ class WhisperModel:
|
|||||||
tokens = segment["tokens"]
|
tokens = segment["tokens"]
|
||||||
all_tokens.extend(tokens)
|
all_tokens.extend(tokens)
|
||||||
|
|
||||||
text = self.decode_text_tokens(tokens)
|
text = tokenizer.decode(tokens)
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -354,14 +355,7 @@ class WhisperModel:
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode_text(self, text):
|
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
||||||
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
|
||||||
|
|
||||||
def decode_text_tokens(self, tokens):
|
|
||||||
text_tokens = [token for token in tokens if token < self.eot_id]
|
|
||||||
return self.tokenizer.decode(text_tokens)
|
|
||||||
|
|
||||||
def generate_with_fallback(self, segment, prompt, options):
|
|
||||||
features = self.get_input(segment)
|
features = self.get_input(segment)
|
||||||
result = None
|
result = None
|
||||||
avg_log_prob = None
|
avg_log_prob = None
|
||||||
@@ -406,7 +400,7 @@ class WhisperModel:
|
|||||||
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
|
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
|
||||||
avg_log_prob = cum_log_prob / (seq_len + 1)
|
avg_log_prob = cum_log_prob / (seq_len + 1)
|
||||||
|
|
||||||
text = self.decode_text_tokens(tokens).strip()
|
text = tokenizer.decode(tokens).strip()
|
||||||
compression_ratio = get_compression_ratio(text)
|
compression_ratio = get_compression_ratio(text)
|
||||||
|
|
||||||
needs_fallback = False
|
needs_fallback = False
|
||||||
@@ -430,33 +424,24 @@ class WhisperModel:
|
|||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
language,
|
tokenizer,
|
||||||
previous_tokens,
|
previous_tokens,
|
||||||
task="transcribe",
|
|
||||||
without_timestamps=False,
|
without_timestamps=False,
|
||||||
prefix=None,
|
prefix=None,
|
||||||
):
|
):
|
||||||
prompt = []
|
prompt = []
|
||||||
|
|
||||||
if previous_tokens:
|
if previous_tokens:
|
||||||
prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
|
prompt.append(tokenizer.sot_prev)
|
||||||
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
||||||
|
|
||||||
prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
|
prompt.extend(tokenizer.sot_sequence)
|
||||||
|
|
||||||
if self.model.is_multilingual:
|
|
||||||
prompt.extend(
|
|
||||||
[
|
|
||||||
self.tokenizer.token_to_id("<|%s|>" % language),
|
|
||||||
self.tokenizer.token_to_id("<|%s|>" % task),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if without_timestamps:
|
if without_timestamps:
|
||||||
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
prompt.append(tokenizer.no_timestamps)
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
prefix_tokens = self.encode_text(" " + prefix.strip())
|
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
||||||
if len(prefix_tokens) >= self.max_length // 2:
|
if len(prefix_tokens) >= self.max_length // 2:
|
||||||
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
||||||
prompt.extend(prefix_tokens)
|
prompt.extend(prefix_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user