diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 402fccd..f65f3d2 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -196,14 +196,16 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
         features = self.feature_extractor(audio)
 
+        whisper_encoder = WhisperEncoder(self.model)
+
         if language is None:
             if not self.model.is_multilingual:
                 language = "en"
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                input = get_ctranslate2_storage(segment)
-                results = self.model.detect_language(input)
+                encoder_output = whisper_encoder(0, segment)
+                results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
         else:
@@ -239,7 +241,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )
 
-        segments = self.generate_segments(features, tokenizer, options)
+        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
 
         audio_info = AudioInfo(
             language=language,
@@ -252,6 +254,7 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
+        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Iterable[Segment]:
@@ -281,8 +284,10 @@ class WhisperModel:
                 prefix=options.prefix,
             )
 
+            encoder_output = whisper_encoder(seek, segment)
+
             result, avg_log_prob, temperature = self.generate_with_fallback(
-                segment, prompt, tokenizer, options
+                encoder_output, prompt, tokenizer, options
             )
 
             if options.no_speech_threshold is not None:
@@ -388,7 +393,7 @@ class WhisperModel:
                     self.add_word_timestamps(
                         current_segments,
                         tokenizer,
-                        segment,
+                        encoder_output,
                         segment_size,
                         options.prepend_punctuations,
                         options.append_punctuations,
@@ -428,12 +433,11 @@ class WhisperModel:
     def generate_with_fallback(
         self,
-        segment: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
-        features = get_ctranslate2_storage(segment)
         result = None
         avg_log_prob = None
         final_temperature = None
 
@@ -458,7 +462,7 @@ class WhisperModel:
             final_temperature = temperature
 
             result = self.model.generate(
-                features,
+                encoder_output,
                 [prompt],
                 length_penalty=options.length_penalty,
                 max_length=self.max_length,
@@ -529,7 +533,7 @@ class WhisperModel:
         self,
         segments: List[dict],
         tokenizer: Tokenizer,
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
@@ -543,7 +547,9 @@ class WhisperModel:
         ]
         text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
 
-        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        alignment = self.find_alignment(
+            tokenizer, text_tokens, encoder_output, num_frames
+        )
         merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
         time_offset = (
@@ -585,7 +591,7 @@ class WhisperModel:
         self,
         tokenizer: Tokenizer,
         text_tokens: List[int],
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         median_filter_width: int = 7,
     ) -> List[dict]:
@@ -593,7 +599,7 @@ class WhisperModel:
             return []
 
         result = self.model.align(
-            get_ctranslate2_storage(mel),
+            encoder_output,
             tokenizer.sot_sequence,
             [text_tokens],
             num_frames,
@@ -646,9 +652,39 @@ class WhisperModel:
         ]
 
 
+class WhisperEncoder:
+    """Helper class to cache and reuse the encoder output."""
+
+    def __init__(self, model: ctranslate2.models.Whisper):
+        self.model = model
+
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        self.cache_on_cpu = len(model.device_index) > 1
+
+        self.last_seek = -1
+        self.last_output = None
+
+    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
+        if self.last_seek == seek:
+            return self.last_output
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
+
+        if self.last_output is not None:
+            del self.last_output
+
+        self.last_seek = seek
+        self.last_output = output
+
+        return output
+
+
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
-    segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment
diff --git a/requirements.txt b/requirements.txt
index 4f981a0..fdecf4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 av==10.*
-ctranslate2>=3.9,<4
+ctranslate2>=3.10,<4
 tokenizers==0.13.*
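Note: the point of keying the cache on seek is that language detection and the first transcription window both encode the audio at seek 0, so the second call is a cache hit and model.encode() runs only once per window. A minimal usage sketch of WhisperEncoder as defined in this patch; the model path, the random features, and the seek values below are placeholders, not part of the patch:

import ctranslate2
import numpy as np

# Hypothetical setup: any CTranslate2-converted Whisper model directory works.
model = ctranslate2.models.Whisper("whisper-tiny-ct2")
encoder = WhisperEncoder(model)

# One 30-second log-Mel window: 80 mel bins x 3000 frames.
mel = np.random.rand(80, 3000).astype(np.float32)

out1 = encoder(0, mel)     # seek 0: calls model.encode()
out2 = encoder(0, mel)     # same seek: returns the cached StorageView
assert out1 is out2

out3 = encoder(3000, mel)  # new seek: re-encodes and replaces the cached output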