diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index be42524..81a9736 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -24,6 +24,7 @@ class TranscriptionOptions( collections.namedtuple( "TranscriptionOptions", ( + "task", "beam_size", "best_of", "patience", @@ -93,6 +94,7 @@ class WhisperModel: self, input_file, language=None, + task="transcribe", beam_size=5, best_of=5, patience=1, @@ -110,6 +112,7 @@ class WhisperModel: input_file: Path to the input file or a file-like object. language: The language spoken in the audio. If not set, the language will be detected in the first 30 seconds of audio. + task: Task to execute (transcribe or translate). beam_size: Beam size to use for decoding. best_of: Number of candidates when sampling with non-zero temperature. patience: Beam search patience factor. @@ -151,6 +154,7 @@ class WhisperModel: language_probability = 1 options = TranscriptionOptions( + task=task, beam_size=beam_size, best_of=best_of, patience=patience, @@ -210,8 +214,12 @@ class WhisperModel: previous_tokens = all_tokens[prompt_reset_since:] prompt = self.get_prompt( - language, previous_tokens, without_timestamps=options.without_timestamps + language, + previous_tokens, + task=options.task, + without_timestamps=options.without_timestamps, ) + result, temperature = self.generate_with_fallback(segment, prompt, options) if ( @@ -323,7 +331,13 @@ class WhisperModel: return result, final_temperature - def get_prompt(self, language, previous_tokens, without_timestamps=False): + def get_prompt( + self, + language, + previous_tokens, + task="transcribe", + without_timestamps=False, + ): prompt = [] if previous_tokens: @@ -333,7 +347,7 @@ class WhisperModel: prompt += [ self.tokenizer.token_to_id("<|startoftranscript|>"), self.tokenizer.token_to_id("<|%s|>" % language), - self.tokenizer.token_to_id("<|transcribe|>"), + self.tokenizer.token_to_id("<|%s|>" % task), ] if without_timestamps: