From c52adaca90ab9db531c2fe95ad3509318da8f4e1 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 9 Mar 2023 12:53:49 +0100 Subject: [PATCH] Create a helper class Tokenizer --- faster_whisper/tokenizer.py | 69 +++++++++++++++++++++++++++++ faster_whisper/transcribe.py | 85 +++++++++++++++--------------------- 2 files changed, 104 insertions(+), 50 deletions(-) create mode 100644 faster_whisper/tokenizer.py diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py new file mode 100644 index 0000000..32c5a39 --- /dev/null +++ b/faster_whisper/tokenizer.py @@ -0,0 +1,69 @@ +from functools import cached_property +from typing import List, Optional + +import tokenizers + + +class Tokenizer: + """Simple wrapper around a tokenizers.Tokenizer.""" + + def __init__( + self, + tokenizer: tokenizers.Tokenizer, + multilingual: bool, + task: Optional[str] = None, + language: Optional[str] = None, + ): + self.tokenizer = tokenizer + + if multilingual: + self.task = self.tokenizer.token_to_id("<|%s|>" % task) + if self.task is None: + raise ValueError("%s is not a valid task" % task) + + self.language = self.tokenizer.token_to_id("<|%s|>" % language) + if self.language is None: + raise ValueError("%s is not a valid language code" % language) + + else: + self.task = None + self.language = None + + @cached_property + def sot(self) -> int: + return self.tokenizer.token_to_id("<|startoftranscript|>") + + @cached_property + def sot_prev(self) -> int: + return self.tokenizer.token_to_id("<|startofprev|>") + + @cached_property + def eot(self) -> int: + return self.tokenizer.token_to_id("<|endoftext|>") + + @cached_property + def no_timestamps(self) -> int: + return self.tokenizer.token_to_id("<|notimestamps|>") + + @property + def timestamp_begin(self) -> int: + return self.no_timestamps + 1 + + @property + def sot_sequence(self) -> List[int]: + sequence = [self.sot] + + if self.language is not None: + sequence.append(self.language) + + if self.task is not None: + sequence.append(self.task) + + return sequence + + def encode(self, text: str) -> List[int]: + return self.tokenizer.encode(text, add_special_tokens=False).ids + + def decode(self, tokens: List[int]) -> str: + text_tokens = [token for token in tokens if token < self.eot] + return self.tokenizer.decode(text_tokens) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index b149792..e27d3a8 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -10,6 +10,7 @@ import tokenizers from faster_whisper.audio import decode_audio from faster_whisper.feature_extractor import FeatureExtractor +from faster_whisper.tokenizer import Tokenizer class Segment(collections.namedtuple("Segment", ("start", "end", "text"))): @@ -26,8 +27,6 @@ class TranscriptionOptions( collections.namedtuple( "TranscriptionOptions", ( - "language", - "task", "beam_size", "best_of", "patience", @@ -88,15 +87,13 @@ class WhisperModel: tokenizer_file = os.path.join(model_path, "tokenizer.json") if os.path.isfile(tokenizer_file): - self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) + self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) else: - self.tokenizer = tokenizers.Tokenizer.from_pretrained( + self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) self.feature_extractor = FeatureExtractor() - self.eot_id = self.tokenizer.token_to_id("<|endoftext|>") - self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1 self.input_stride = 2 self.time_precision = 0.02 self.max_length = 448 @@ -187,13 +184,16 @@ class WhisperModel: language_token, language_probability = results[0][0] language = language_token[2:-2] else: - if self.tokenizer.token_to_id("<|%s|>" % language) is None: - raise ValueError("%s is not a valid language code" % language) language_probability = 1 - options = TranscriptionOptions( - language=language, + tokenizer = Tokenizer( + self.hf_tokenizer, + self.model.is_multilingual, task=task, + language=language, + ) + + options = TranscriptionOptions( beam_size=beam_size, best_of=best_of, patience=patience, @@ -213,7 +213,7 @@ class WhisperModel: max_initial_timestamp=max_initial_timestamp, ) - segments = self.generate_segments(features, options) + segments = self.generate_segments(features, tokenizer, options) audio_info = AudioInfo( language=language, @@ -222,7 +222,7 @@ class WhisperModel: return segments, audio_info - def generate_segments(self, features, options): + def generate_segments(self, features, tokenizer, options): content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames seek = 0 all_tokens = [] @@ -230,7 +230,7 @@ class WhisperModel: if options.initial_prompt is not None: initial_prompt = " " + options.initial_prompt.strip() - initial_prompt_tokens = self.encode_text(initial_prompt) + initial_prompt_tokens = tokenizer.encode(initial_prompt) all_tokens.extend(initial_prompt_tokens) while seek < content_frames: @@ -243,15 +243,14 @@ class WhisperModel: previous_tokens = all_tokens[prompt_reset_since:] prompt = self.get_prompt( - options.language, + tokenizer, previous_tokens, - task=options.task, without_timestamps=options.without_timestamps, prefix=options.prefix, ) result, avg_log_prob, temperature = self.generate_with_fallback( - segment, prompt, options + segment, prompt, tokenizer, options ) if options.no_speech_threshold is not None: @@ -276,16 +275,16 @@ class WhisperModel: single_timestamp_ending = ( len(tokens) >= 2 - and tokens[-2] < self.timestamp_begin_id - and tokens[-1] >= self.timestamp_begin_id + and tokens[-2] < tokenizer.timestamp_begin + and tokens[-1] >= tokenizer.timestamp_begin ) consecutive_timestamps = [ i for i in range(len(tokens)) if i > 0 - and tokens[i] >= self.timestamp_begin_id - and tokens[i - 1] >= self.timestamp_begin_id + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin ] if len(consecutive_timestamps) > 0: @@ -297,9 +296,11 @@ class WhisperModel: for current_slice in slices: sliced_tokens = tokens[last_slice:current_slice] start_timestamp_position = ( - sliced_tokens[0] - self.timestamp_begin_id + sliced_tokens[0] - tokenizer.timestamp_begin + ) + end_timestamp_position = ( + sliced_tokens[-1] - tokenizer.timestamp_begin ) - end_timestamp_position = sliced_tokens[-1] - self.timestamp_begin_id start_time = ( time_offset + start_timestamp_position * self.time_precision ) @@ -318,17 +319,17 @@ class WhisperModel: else: # otherwise, ignore the unfinished segment and seek to the last timestamp last_timestamp_position = ( - tokens[last_slice - 1] - self.timestamp_begin_id + tokens[last_slice - 1] - tokenizer.timestamp_begin ) seek += last_timestamp_position * self.input_stride else: duration = segment_duration timestamps = [ - token for token in tokens if token >= self.timestamp_begin_id + token for token in tokens if token >= tokenizer.timestamp_begin ] - if len(timestamps) > 0 and timestamps[-1] != self.timestamp_begin_id: - last_timestamp_position = timestamps[-1] - self.timestamp_begin_id + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin duration = last_timestamp_position * self.time_precision current_segments.append( @@ -344,7 +345,7 @@ class WhisperModel: tokens = segment["tokens"] all_tokens.extend(tokens) - text = self.decode_text_tokens(tokens) + text = tokenizer.decode(tokens) if not text.strip(): continue @@ -354,14 +355,7 @@ class WhisperModel: text=text, ) - def encode_text(self, text): - return self.tokenizer.encode(text, add_special_tokens=False).ids - - def decode_text_tokens(self, tokens): - text_tokens = [token for token in tokens if token < self.eot_id] - return self.tokenizer.decode(text_tokens) - - def generate_with_fallback(self, segment, prompt, options): + def generate_with_fallback(self, segment, prompt, tokenizer, options): features = self.get_input(segment) result = None avg_log_prob = None @@ -406,7 +400,7 @@ class WhisperModel: cum_log_prob = result.scores[0] * (seq_len**options.length_penalty) avg_log_prob = cum_log_prob / (seq_len + 1) - text = self.decode_text_tokens(tokens).strip() + text = tokenizer.decode(tokens).strip() compression_ratio = get_compression_ratio(text) needs_fallback = False @@ -430,33 +424,24 @@ class WhisperModel: def get_prompt( self, - language, + tokenizer, previous_tokens, - task="transcribe", without_timestamps=False, prefix=None, ): prompt = [] if previous_tokens: - prompt.append(self.tokenizer.token_to_id("<|startofprev|>")) + prompt.append(tokenizer.sot_prev) prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) - prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>")) - - if self.model.is_multilingual: - prompt.extend( - [ - self.tokenizer.token_to_id("<|%s|>" % language), - self.tokenizer.token_to_id("<|%s|>" % task), - ] - ) + prompt.extend(tokenizer.sot_sequence) if without_timestamps: - prompt.append(self.tokenizer.token_to_id("<|notimestamps|>")) + prompt.append(tokenizer.no_timestamps) if prefix: - prefix_tokens = self.encode_text(" " + prefix.strip()) + prefix_tokens = tokenizer.encode(" " + prefix.strip()) if len(prefix_tokens) >= self.max_length // 2: prefix_tokens = prefix_tokens[: self.max_length // 2 - 1] prompt.extend(prefix_tokens)