Simplify reuse of the encoder output

This commit is contained in:
Guillaume Klein
2023-03-30 15:58:27 +02:00
parent 39fddba886
commit d03383f902

View File

@@ -205,7 +205,7 @@ class WhisperModel:
duration = audio.shape[0] / self.feature_extractor.sampling_rate
features = self.feature_extractor(audio)
whisper_encoder = WhisperEncoder(self.model)
encoder_output = None
if language is None:
if not self.model.is_multilingual:
@@ -213,7 +213,7 @@ class WhisperModel:
language_probability = 1
else:
segment = features[:, : self.feature_extractor.nb_max_frames]
encoder_output = whisper_encoder(0, segment)
encoder_output = self.encode(segment)
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
@@ -250,7 +250,7 @@ class WhisperModel:
append_punctuations=append_punctuations,
)
segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
segments = self.generate_segments(features, tokenizer, options, encoder_output)
audio_info = AudioInfo(
language=language,
@@ -263,9 +263,9 @@ class WhisperModel:
def generate_segments(
self,
features: np.ndarray,
whisper_encoder: "WhisperEncoder",
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
seek = 0
@@ -293,7 +293,8 @@ class WhisperModel:
prefix=options.prefix,
)
encoder_output = whisper_encoder(seek, segment)
if encoder_output is None:
encoder_output = self.encode(segment)
result, avg_log_prob, temperature = self.generate_with_fallback(
encoder_output, prompt, tokenizer, options
@@ -420,6 +421,8 @@ class WhisperModel:
if seek_shift > 0:
seek = previous_seek + seek_shift
encoder_output = None
for segment in current_segments:
tokens = segment["tokens"]
text = tokenizer.decode(tokens)
@@ -440,6 +443,16 @@ class WhisperModel:
),
)
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
features = np.expand_dims(features, 0)
features = get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
def generate_with_fallback(
self,
encoder_output: ctranslate2.StorageView,
@@ -661,37 +674,6 @@ class WhisperModel:
]
class WhisperEncoder:
"""Helper class to cache and reuse the encoder output."""
def __init__(self, model: ctranslate2.models.Whisper):
self.model = model
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
self.cache_on_cpu = len(model.device_index) > 1
self.last_seek = -1
self.last_output = None
def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
if self.last_seek == seek:
return self.last_output
features = np.expand_dims(features, 0)
features = get_ctranslate2_storage(features)
output = self.model.encode(features, to_cpu=self.cache_on_cpu)
if self.last_output is not None:
del self.last_output
self.last_seek = seek
self.last_output = output
return output
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
segment = np.ascontiguousarray(segment)
segment = ctranslate2.StorageView.from_array(segment)