From 123d9a57041c86468bdf06d30c50560d4c2fb9e5 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 16 Feb 2023 17:02:40 +0100 Subject: [PATCH] Support English-only models --- faster_whisper/transcribe.py | 46 ++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 4ec28d4..f589d85 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -1,5 +1,4 @@ import collections -import os import zlib import ctranslate2 @@ -78,21 +77,10 @@ class WhisperModel: inter_threads=num_workers, ) - with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file: - vocab_size = sum(1 for _ in vocab_file) - - is_multilingual = vocab_size == 51865 - if not is_multilingual: - raise NotImplementedError( - "English-only models are currently not supported. " - "The underlying CTranslate2 implementation makes some assumptions about " - "the prompt format that are not compatible with English-only models. " - "This will be improved in a future version. " - "Please use a multilingual model for now." - ) - self.feature_extractor = FeatureExtractor() - self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny") + self.tokenizer = tokenizers.Tokenizer.from_pretrained( + "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") + ) self.eot_id = self.tokenizer.token_to_id("<|endoftext|>") self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1 self.input_stride = 2 @@ -154,11 +142,15 @@ class WhisperModel: features = self.feature_extractor(audio) if language is None: - segment = self.get_segment(features) - input = self.get_input(segment) - results = self.model.detect_language(input) - language_token, language_probability = results[0][0] - language = language_token[2:-2] + if not self.model.is_multilingual: + language = "en" + language_probability = 1 + else: + segment = self.get_segment(features) + input = self.get_input(segment) + results = self.model.detect_language(input) + language_token, language_probability = results[0][0] + language = language_token[2:-2] else: language_probability = 1 @@ -353,11 +345,15 @@ class WhisperModel: prompt.append(self.tokenizer.token_to_id("<|startofprev|>")) prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) - prompt += [ - self.tokenizer.token_to_id("<|startoftranscript|>"), - self.tokenizer.token_to_id("<|%s|>" % language), - self.tokenizer.token_to_id("<|%s|>" % task), - ] + prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>")) + + if self.model.is_multilingual: + prompt.extend( + [ + self.tokenizer.token_to_id("<|%s|>" % language), + self.tokenizer.token_to_id("<|%s|>" % task), + ] + ) if without_timestamps: prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))