From 7d1d0541c8f69374d428b6cc4e4e4da36c42ac19 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Sun, 12 Feb 2023 11:42:21 +0100 Subject: [PATCH] Add the initial_prompt parameter (#2) * Add the initial_prompt parameter * Add docstring --- faster_whisper/transcribe.py | 39 ++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 99f9205..349adb8 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -32,6 +32,7 @@ class TranscriptionOptions( "compression_ratio_threshold", "condition_on_previous_text", "temperatures", + "initial_prompt", ), ) ): @@ -64,16 +65,9 @@ class WhisperModel: ) self.feature_extractor = FeatureExtractor() - self.decoder = tokenizers.decoders.ByteLevel() - - with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file: - self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file] - self.tokens_to_ids = { - token: i for i, token in enumerate(self.ids_to_tokens) - } - - self.eot_id = self.tokens_to_ids["<|endoftext|>"] - self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1 + self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny") + self.eot_id = self.tokenizer.token_to_id("<|endoftext|>") + self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1 self.input_stride = 2 self.time_precision = 0.02 self.max_length = 448 @@ -90,6 +84,7 @@ class WhisperModel: log_prob_threshold=-1.0, no_speech_threshold=0.6, condition_on_previous_text=True, + initial_prompt=None, ): """Transcribes an input file. @@ -114,6 +109,7 @@ class WhisperModel: as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + initial_prompt: Optional text to provide as a prompt for the first window. Returns: A tuple with: @@ -146,6 +142,7 @@ class WhisperModel: temperatures=( temperature if isinstance(temperature, (list, tuple)) else [temperature] ), + initial_prompt=initial_prompt, ) segments = self.generate_segments(features, language, options) @@ -179,6 +176,13 @@ class WhisperModel: all_tokens = [] prompt_reset_since = 0 + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = self.tokenizer.encode( + initial_prompt, add_special_tokens=False + ) + all_tokens.extend(initial_prompt_tokens.ids) + while offset < num_frames: time_offset = offset * self.feature_extractor.time_per_frame segment = self.get_segment(features, offset) @@ -253,11 +257,8 @@ class WhisperModel: prompt_reset_since = len(all_tokens) def decode_text_tokens(self, tokens): - text_tokens = [ - self.ids_to_tokens[token] for token in tokens if token < self.eot_id - ] - - return self.decoder.decode(text_tokens) + text_tokens = [token for token in tokens if token < self.eot_id] + return self.tokenizer.decode(text_tokens) def generate_with_fallback(self, segment, prompt, options): features = self.get_input(segment) @@ -304,13 +305,13 @@ class WhisperModel: prompt = [] if previous_tokens: - prompt.append(self.tokens_to_ids["<|startofprev|>"]) + prompt.append(self.tokenizer.token_to_id("<|startofprev|>")) prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) prompt += [ - self.tokens_to_ids["<|startoftranscript|>"], - self.tokens_to_ids["<|%s|>" % language], - self.tokens_to_ids["<|transcribe|>"], + self.tokenizer.token_to_id("<|startoftranscript|>"), + self.tokenizer.token_to_id("<|%s|>" % language), + self.tokenizer.token_to_id("<|transcribe|>"), ] return prompt