Run the encoder only once for each 30-second window (#73)

This commit is contained in:
Guillaume Klein
2023-03-24 10:53:49 +01:00
committed by GitHub
parent 2b7be47041
commit 523ae2180f
2 changed files with 50 additions and 14 deletions

View File

@@ -196,14 +196,16 @@ class WhisperModel:
duration = audio.shape[0] / self.feature_extractor.sampling_rate duration = audio.shape[0] / self.feature_extractor.sampling_rate
features = self.feature_extractor(audio) features = self.feature_extractor(audio)
whisper_encoder = WhisperEncoder(self.model)
if language is None: if language is None:
if not self.model.is_multilingual: if not self.model.is_multilingual:
language = "en" language = "en"
language_probability = 1 language_probability = 1
else: else:
segment = features[:, : self.feature_extractor.nb_max_frames] segment = features[:, : self.feature_extractor.nb_max_frames]
input = get_ctranslate2_storage(segment) encoder_output = whisper_encoder(0, segment)
results = self.model.detect_language(input) results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0] language_token, language_probability = results[0][0]
language = language_token[2:-2] language = language_token[2:-2]
else: else:
@@ -239,7 +241,7 @@ class WhisperModel:
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
) )
segments = self.generate_segments(features, tokenizer, options) segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
audio_info = AudioInfo( audio_info = AudioInfo(
language=language, language=language,
@@ -252,6 +254,7 @@ class WhisperModel:
def generate_segments( def generate_segments(
self, self,
features: np.ndarray, features: np.ndarray,
whisper_encoder: "WhisperEncoder",
tokenizer: Tokenizer, tokenizer: Tokenizer,
options: TranscriptionOptions, options: TranscriptionOptions,
) -> Iterable[Segment]: ) -> Iterable[Segment]:
@@ -281,8 +284,10 @@ class WhisperModel:
prefix=options.prefix, prefix=options.prefix,
) )
encoder_output = whisper_encoder(seek, segment)
result, avg_log_prob, temperature = self.generate_with_fallback( result, avg_log_prob, temperature = self.generate_with_fallback(
segment, prompt, tokenizer, options encoder_output, prompt, tokenizer, options
) )
if options.no_speech_threshold is not None: if options.no_speech_threshold is not None:
@@ -388,7 +393,7 @@ class WhisperModel:
self.add_word_timestamps( self.add_word_timestamps(
current_segments, current_segments,
tokenizer, tokenizer,
segment, encoder_output,
segment_size, segment_size,
options.prepend_punctuations, options.prepend_punctuations,
options.append_punctuations, options.append_punctuations,
@@ -428,12 +433,11 @@ class WhisperModel:
def generate_with_fallback( def generate_with_fallback(
self, self,
segment: np.ndarray, encoder_output: ctranslate2.StorageView,
prompt: List[int], prompt: List[int],
tokenizer: Tokenizer, tokenizer: Tokenizer,
options: TranscriptionOptions, options: TranscriptionOptions,
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]: ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
features = get_ctranslate2_storage(segment)
result = None result = None
avg_log_prob = None avg_log_prob = None
final_temperature = None final_temperature = None
@@ -458,7 +462,7 @@ class WhisperModel:
final_temperature = temperature final_temperature = temperature
result = self.model.generate( result = self.model.generate(
features, encoder_output,
[prompt], [prompt],
length_penalty=options.length_penalty, length_penalty=options.length_penalty,
max_length=self.max_length, max_length=self.max_length,
@@ -529,7 +533,7 @@ class WhisperModel:
self, self,
segments: List[dict], segments: List[dict],
tokenizer: Tokenizer, tokenizer: Tokenizer,
mel: np.ndarray, encoder_output: ctranslate2.StorageView,
num_frames: int, num_frames: int,
prepend_punctuations: str, prepend_punctuations: str,
append_punctuations: str, append_punctuations: str,
@@ -543,7 +547,9 @@ class WhisperModel:
] ]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames) alignment = self.find_alignment(
tokenizer, text_tokens, encoder_output, num_frames
)
merge_punctuations(alignment, prepend_punctuations, append_punctuations) merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = ( time_offset = (
@@ -585,7 +591,7 @@ class WhisperModel:
self, self,
tokenizer: Tokenizer, tokenizer: Tokenizer,
text_tokens: List[int], text_tokens: List[int],
mel: np.ndarray, encoder_output: ctranslate2.StorageView,
num_frames: int, num_frames: int,
median_filter_width: int = 7, median_filter_width: int = 7,
) -> List[dict]: ) -> List[dict]:
@@ -593,7 +599,7 @@ class WhisperModel:
return [] return []
result = self.model.align( result = self.model.align(
get_ctranslate2_storage(mel), encoder_output,
tokenizer.sot_sequence, tokenizer.sot_sequence,
[text_tokens], [text_tokens],
num_frames, num_frames,
@@ -646,9 +652,39 @@ class WhisperModel:
] ]
class WhisperEncoder:
    """Memoizes the Whisper encoder output for the most recent audio window.

    Transcription repeatedly decodes against the same 30-second window while
    `seek` is unchanged; caching the encoder output avoids re-running the
    encoder for every decoding attempt.
    """

    def __init__(self, model: ctranslate2.models.Whisper):
        self.model = model
        # When the model is running on multiple GPUs, the encoder output should
        # be moved to the CPU since we don't know which GPU will handle the
        # next job.
        self.cache_on_cpu = len(model.device_index) > 1
        # Sentinel -1 guarantees the first call is always a cache miss.
        self.last_seek = -1
        self.last_output = None

    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
        # Cache hit: the same window was encoded by the previous call.
        if seek == self.last_seek:
            return self.last_output

        # Add the batch dimension expected by the encoder, then run it.
        batched = get_ctranslate2_storage(np.expand_dims(features, 0))
        encoded = self.model.encode(batched, to_cpu=self.cache_on_cpu)

        # Explicitly release the previous cached output before replacing it.
        if self.last_output is not None:
            del self.last_output

        self.last_seek = seek
        self.last_output = encoded
        return encoded
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    """Wrap a NumPy array in a CTranslate2 ``StorageView``.

    The array is made C-contiguous first, as ``StorageView.from_array``
    requires a contiguous buffer. The caller is responsible for adding a
    batch dimension when the consumer expects one (``WhisperEncoder.__call__``
    does this itself before calling this helper).
    """
    segment = np.ascontiguousarray(segment)
    return ctranslate2.StorageView.from_array(segment)

View File

@@ -1,3 +1,3 @@
av==10.* av==10.*
ctranslate2>=3.9,<4 ctranslate2>=3.10,<4
tokenizers==0.13.* tokenizers==0.13.*