Support English-only models

This commit is contained in:
Guillaume Klein
2023-02-16 17:02:40 +01:00
parent cda834c8ea
commit 123d9a5704

View File

@@ -1,5 +1,4 @@
import collections
import os
import zlib
import ctranslate2
@@ -78,21 +77,10 @@ class WhisperModel:
inter_threads=num_workers,
)
with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
vocab_size = sum(1 for _ in vocab_file)
is_multilingual = vocab_size == 51865
if not is_multilingual:
raise NotImplementedError(
"English-only models are currently not supported. "
"The underlying CTranslate2 implementation makes some assumptions about "
"the prompt format that are not compatible with English-only models. "
"This will be improved in a future version. "
"Please use a multilingual model for now."
)
self.feature_extractor = FeatureExtractor()
self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
self.tokenizer = tokenizers.Tokenizer.from_pretrained(
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
)
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
self.input_stride = 2
@@ -154,6 +142,10 @@ class WhisperModel:
features = self.feature_extractor(audio)
if language is None:
if not self.model.is_multilingual:
language = "en"
language_probability = 1
else:
segment = self.get_segment(features)
input = self.get_input(segment)
results = self.model.detect_language(input)
@@ -353,11 +345,15 @@ class WhisperModel:
prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt += [
self.tokenizer.token_to_id("<|startoftranscript|>"),
prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
if self.model.is_multilingual:
prompt.extend(
[
self.tokenizer.token_to_id("<|%s|>" % language),
self.tokenizer.token_to_id("<|%s|>" % task),
]
)
if without_timestamps:
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))