diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index a4a7878..5d1d109 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -36,6 +36,7 @@ class TranscriptionOptions(
             "condition_on_previous_text",
             "temperatures",
             "initial_prompt",
+            "prefix",
             "without_timestamps",
         ),
     )
@@ -112,6 +113,7 @@ class WhisperModel:
         no_speech_threshold: Optional[float] = 0.6,
         condition_on_previous_text: bool = True,
         initial_prompt: Optional[str] = None,
+        prefix: Optional[str] = None,
         without_timestamps: bool = False,
     ):
         """Transcribes an input file.
@@ -141,6 +143,7 @@ class WhisperModel:
             windows, but the model becomes less prone to getting stuck in a failure
             loop, such as repetition looping or timestamps going out of sync.
           initial_prompt: Optional text to provide as a prompt for the first window.
+          prefix: Optional text to provide as a prefix for the first window.
           without_timestamps: Only sample text tokens.
 
         Returns:
@@ -183,6 +186,7 @@ class WhisperModel:
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
+            prefix=prefix,
             without_timestamps=without_timestamps,
         )
 
@@ -219,10 +223,8 @@ class WhisperModel:
 
         if options.initial_prompt is not None:
             initial_prompt = " " + options.initial_prompt.strip()
-            initial_prompt_tokens = self.tokenizer.encode(
-                initial_prompt, add_special_tokens=False
-            )
-            all_tokens.extend(initial_prompt_tokens.ids)
+            initial_prompt_tokens = self.encode_text(initial_prompt)
+            all_tokens.extend(initial_prompt_tokens)
 
         while offset < num_frames:
             time_offset = offset * self.feature_extractor.time_per_frame
@@ -235,6 +237,7 @@ class WhisperModel:
                 previous_tokens,
                 task=options.task,
                 without_timestamps=options.without_timestamps,
+                prefix=options.prefix,
             )
 
             result, avg_log_prob, temperature = self.generate_with_fallback(
@@ -314,6 +317,9 @@ class WhisperModel:
             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)
 
+    def encode_text(self, text):
+        return self.tokenizer.encode(text, add_special_tokens=False).ids
+
     def decode_text_tokens(self, tokens):
         text_tokens = [token for token in tokens if token < self.eot_id]
         return self.tokenizer.decode(text_tokens)
@@ -384,6 +390,7 @@ class WhisperModel:
         previous_tokens,
         task="transcribe",
         without_timestamps=False,
+        prefix=None,
     ):
         prompt = []
 
@@ -404,6 +411,12 @@ class WhisperModel:
         if without_timestamps:
             prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
 
+        if prefix:
+            prefix_tokens = self.encode_text(" " + prefix.strip())
+            if len(prefix_tokens) >= self.max_length // 2:
+                prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
+            prompt.extend(prefix_tokens)
+
         return prompt
 
     def get_segment(self, features, offset=0):