Simplify reuse of the encoder output
This commit is contained in:
@@ -205,7 +205,7 @@ class WhisperModel:
|
||||
duration = audio.shape[0] / self.feature_extractor.sampling_rate
|
||||
features = self.feature_extractor(audio)
|
||||
|
||||
whisper_encoder = WhisperEncoder(self.model)
|
||||
encoder_output = None
|
||||
|
||||
if language is None:
|
||||
if not self.model.is_multilingual:
|
||||
@@ -213,7 +213,7 @@ class WhisperModel:
|
||||
language_probability = 1
|
||||
else:
|
||||
segment = features[:, : self.feature_extractor.nb_max_frames]
|
||||
encoder_output = whisper_encoder(0, segment)
|
||||
encoder_output = self.encode(segment)
|
||||
results = self.model.detect_language(encoder_output)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
@@ -250,7 +250,7 @@ class WhisperModel:
|
||||
append_punctuations=append_punctuations,
|
||||
)
|
||||
|
||||
segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
|
||||
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
||||
|
||||
audio_info = AudioInfo(
|
||||
language=language,
|
||||
@@ -263,9 +263,9 @@ class WhisperModel:
|
||||
def generate_segments(
|
||||
self,
|
||||
features: np.ndarray,
|
||||
whisper_encoder: "WhisperEncoder",
|
||||
tokenizer: Tokenizer,
|
||||
options: TranscriptionOptions,
|
||||
encoder_output: Optional[ctranslate2.StorageView] = None,
|
||||
) -> Iterable[Segment]:
|
||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||
seek = 0
|
||||
@@ -293,7 +293,8 @@ class WhisperModel:
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
encoder_output = whisper_encoder(seek, segment)
|
||||
if encoder_output is None:
|
||||
encoder_output = self.encode(segment)
|
||||
|
||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
||||
encoder_output, prompt, tokenizer, options
|
||||
@@ -420,6 +421,8 @@ class WhisperModel:
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
encoder_output = None
|
||||
|
||||
for segment in current_segments:
|
||||
tokens = segment["tokens"]
|
||||
text = tokenizer.decode(tokens)
|
||||
@@ -440,6 +443,16 @@ class WhisperModel:
|
||||
),
|
||||
)
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
    """Run the model encoder on ``features`` and return its output.

    A leading axis of size one is prepended to ``features`` and the result
    is converted with ``get_ctranslate2_storage`` before being handed to
    ``self.model.encode``.
    """
    # With several GPU indices configured we cannot know which GPU will
    # pick up the next job, so ask for the encoder output on the CPU.
    runs_on_multiple_gpus = (
        self.model.device == "cuda" and len(self.model.device_index) > 1
    )

    batched = get_ctranslate2_storage(np.expand_dims(features, 0))
    return self.model.encode(batched, to_cpu=runs_on_multiple_gpus)
|
||||
|
||||
def generate_with_fallback(
|
||||
self,
|
||||
encoder_output: ctranslate2.StorageView,
|
||||
@@ -661,37 +674,6 @@ class WhisperModel:
|
||||
]
|
||||
|
||||
|
||||
class WhisperEncoder:
    """Helper class to cache and reuse the encoder output.

    Only the most recent output is kept, together with the ``seek`` value it
    was computed for. Calling the instance again with the same ``seek``
    returns the stored output without running the encoder.

    NOTE: the cache is keyed on ``seek`` only; the ``features`` argument is
    not compared, so an instance should not be shared between different
    inputs.
    """

    def __init__(self, model: "ctranslate2.models.Whisper"):
        self.model = model

        # When the model is running on multiple GPUs, the encoder output
        # should be moved to the CPU since we don't know which GPU will
        # handle the next job. The device check keeps this condition
        # identical to the one used by ``WhisperModel.encode``.
        self.cache_on_cpu = model.device == "cuda" and len(model.device_index) > 1

        # Initial seek value for an empty cache (no output stored yet).
        self.last_seek = -1
        self.last_output = None

    def __call__(self, seek: int, features: "np.ndarray") -> "ctranslate2.StorageView":
        """Return the encoder output for ``features`` at position ``seek``.

        Args:
            seek: Position used as the cache key.
            features: Input features; a leading axis of size one is added
                before they are passed to the model.

        Returns:
            The value returned by ``self.model.encode``, or the stored
            output when ``seek`` matches the previous call.
        """
        if self.last_seek == seek:
            return self.last_output

        features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)

        output = self.model.encode(features, to_cpu=self.cache_on_cpu)

        # Rebinding drops the reference to the previous output; an explicit
        # ``del`` beforehand is not needed.
        self.last_seek = seek
        self.last_output = output

        return output
|
||||
|
||||
|
||||
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
||||
segment = np.ascontiguousarray(segment)
|
||||
segment = ctranslate2.StorageView.from_array(segment)
|
||||
|
||||
Reference in New Issue
Block a user