Support English-only models
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import collections
|
||||
import os
|
||||
import zlib
|
||||
|
||||
import ctranslate2
|
||||
@@ -78,21 +77,10 @@ class WhisperModel:
|
||||
inter_threads=num_workers,
|
||||
)
|
||||
|
||||
with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
|
||||
vocab_size = sum(1 for _ in vocab_file)
|
||||
|
||||
is_multilingual = vocab_size == 51865
|
||||
if not is_multilingual:
|
||||
raise NotImplementedError(
|
||||
"English-only models are currently not supported. "
|
||||
"The underlying CTranslate2 implementation makes some assumptions about "
|
||||
"the prompt format that are not compatible with English-only models. "
|
||||
"This will be improved in a future version. "
|
||||
"Please use a multilingual model for now."
|
||||
)
|
||||
|
||||
self.feature_extractor = FeatureExtractor()
|
||||
self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
|
||||
self.tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
||||
)
|
||||
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
|
||||
self.input_stride = 2
|
||||
@@ -154,11 +142,15 @@ class WhisperModel:
|
||||
features = self.feature_extractor(audio)
|
||||
|
||||
if language is None:
|
||||
segment = self.get_segment(features)
|
||||
input = self.get_input(segment)
|
||||
results = self.model.detect_language(input)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
if not self.model.is_multilingual:
|
||||
language = "en"
|
||||
language_probability = 1
|
||||
else:
|
||||
segment = self.get_segment(features)
|
||||
input = self.get_input(segment)
|
||||
results = self.model.detect_language(input)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
else:
|
||||
language_probability = 1
|
||||
|
||||
@@ -353,11 +345,15 @@ class WhisperModel:
|
||||
prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
|
||||
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
||||
|
||||
prompt += [
|
||||
self.tokenizer.token_to_id("<|startoftranscript|>"),
|
||||
self.tokenizer.token_to_id("<|%s|>" % language),
|
||||
self.tokenizer.token_to_id("<|%s|>" % task),
|
||||
]
|
||||
prompt.append(self.tokenizer.token_to_id("<|startoftranscript|>"))
|
||||
|
||||
if self.model.is_multilingual:
|
||||
prompt.extend(
|
||||
[
|
||||
self.tokenizer.token_to_id("<|%s|>" % language),
|
||||
self.tokenizer.token_to_id("<|%s|>" % task),
|
||||
]
|
||||
)
|
||||
|
||||
if without_timestamps:
|
||||
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
||||
|
||||
Reference in New Issue
Block a user