From d03383f90230622d8e1a5796b3a13c54e12fcd62 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Thu, 30 Mar 2023 15:58:27 +0200
Subject: [PATCH] Simplify reuse of the encoder output

---
 faster_whisper/transcribe.py | 54 ++++++++++++------------------------
 1 file changed, 18 insertions(+), 36 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index d716cb3..fb54f85 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -205,7 +205,7 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
 
         features = self.feature_extractor(audio)
-        whisper_encoder = WhisperEncoder(self.model)
+        encoder_output = None
 
         if language is None:
             if not self.model.is_multilingual:
@@ -213,7 +213,7 @@ class WhisperModel:
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                encoder_output = whisper_encoder(0, segment)
+                encoder_output = self.encode(segment)
                 results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
@@ -250,7 +250,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )
 
-        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
+        segments = self.generate_segments(features, tokenizer, options, encoder_output)
 
         audio_info = AudioInfo(
             language=language,
@@ -263,9 +263,9 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
-        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
+        encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         seek = 0
@@ -293,7 +293,8 @@ class WhisperModel:
                 prefix=options.prefix,
             )
 
-            encoder_output = whisper_encoder(seek, segment)
+            if encoder_output is None:
+                encoder_output = self.encode(segment)
 
             result, avg_log_prob, temperature = self.generate_with_fallback(
                 encoder_output, prompt, tokenizer, options
@@ -420,6 +421,8 @@ class WhisperModel:
                 if seek_shift > 0:
                     seek = previous_seek + seek_shift
 
+            encoder_output = None
+
             for segment in current_segments:
                 tokens = segment["tokens"]
                 text = tokenizer.decode(tokens)
@@ -440,6 +443,16 @@ class WhisperModel:
                 ),
             )
 
+    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        return self.model.encode(features, to_cpu=to_cpu)
+
     def generate_with_fallback(
         self,
         encoder_output: ctranslate2.StorageView,
@@ -661,37 +674,6 @@ class WhisperModel:
         ]
 
 
-class WhisperEncoder:
-    """Helper class to cache and reuse the encoder output."""
-
-    def __init__(self, model: ctranslate2.models.Whisper):
-        self.model = model
-
-        # When the model is running on multiple GPUs, the encoder output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
-        self.cache_on_cpu = len(model.device_index) > 1
-
-        self.last_seek = -1
-        self.last_output = None
-
-    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
-        if self.last_seek == seek:
-            return self.last_output
-
-        features = np.expand_dims(features, 0)
-        features = get_ctranslate2_storage(features)
-
-        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
-
-        if self.last_output is not None:
-            del self.last_output
-
-        self.last_seek = seek
-        self.last_output = output
-
-        return output
-
-
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
     segment = ctranslate2.StorageView.from_array(segment)