Simplify reuse of the encoder output
@@ -205,7 +205,7 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
         features = self.feature_extractor(audio)
 
-        whisper_encoder = WhisperEncoder(self.model)
+        encoder_output = None
 
         if language is None:
             if not self.model.is_multilingual:
@@ -213,7 +213,7 @@ class WhisperModel:
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                encoder_output = whisper_encoder(0, segment)
+                encoder_output = self.encode(segment)
                 results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
@@ -250,7 +250,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )
 
-        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
+        segments = self.generate_segments(features, tokenizer, options, encoder_output)
 
         audio_info = AudioInfo(
             language=language,
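Taken together, the first three hunks make the reuse explicit: when the language is auto-detected, the encoder output computed for the first 30-second window is kept in encoder_output and passed straight into generate_segments, so the first transcription window is not encoded twice. A condensed sketch of the resulting control flow (simplified from the hunks above, with the surrounding method elided):

    encoder_output = None

    if language is None:
        # Encode the first window once; language detection consumes the encoder output.
        segment = features[:, : self.feature_extractor.nb_max_frames]
        encoder_output = self.encode(segment)
        results = self.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        language = language_token[2:-2]  # strip "<|" and "|>" from e.g. "<|en|>"

    # The same encoder output seeds the first iteration of generate_segments.
    segments = self.generate_segments(features, tokenizer, options, encoder_output)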
@@ -263,9 +263,9 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
-        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
+        encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         seek = 0
@@ -293,7 +293,8 @@ class WhisperModel:
                 prefix=options.prefix,
             )
 
-            encoder_output = whisper_encoder(seek, segment)
+            if encoder_output is None:
+                encoder_output = self.encode(segment)
 
             result, avg_log_prob, temperature = self.generate_with_fallback(
                 encoder_output, prompt, tokenizer, options
@@ -420,6 +421,8 @@ class WhisperModel:
                     if seek_shift > 0:
                         seek = previous_seek + seek_shift
 
+            encoder_output = None
+
             for segment in current_segments:
                 tokens = segment["tokens"]
                 text = tokenizer.decode(tokens)
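With these two hunks, encoder_output behaves as a one-slot cache inside the decoding loop: it is filled lazily at the top of each iteration and reset to None once seek advances, which is the staleness tracking the deleted WhisperEncoder class used to do with last_seek. A rough sketch of the loop shape after this change (most of the body elided):

    while seek < content_frames:
        segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]

        # Reuse a pre-computed output (e.g. from language detection) if one exists.
        if encoder_output is None:
            encoder_output = self.encode(segment)

        result, avg_log_prob, temperature = self.generate_with_fallback(
            encoder_output, prompt, tokenizer, options
        )

        # ... emit segments and advance seek ...

        encoder_output = None  # seek moved, so the cached output is stale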
@@ -440,6 +443,16 @@ class WhisperModel:
             ),
         )
 
+    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        return self.model.encode(features, to_cpu=to_cpu)
+
     def generate_with_fallback(
         self,
         encoder_output: ctranslate2.StorageView,
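The new encode method folds the old wrapper's device handling into WhisperModel itself, with one refinement visible in the hunk: to_cpu is now also gated on the device being "cuda", so multi-index CPU setups no longer copy the output needlessly. A standalone sketch of the same CTranslate2 calls, assuming a converted model directory named "whisper-base-ct2" and dummy features (both placeholders):

    import ctranslate2
    import numpy as np

    model = ctranslate2.models.Whisper(
        "whisper-base-ct2", device="cuda", device_index=[0, 1]
    )

    # One 30-second window of log-Mel features: 80 mel bins x 3000 frames.
    mel = np.zeros((80, 3000), dtype=np.float32)
    features = ctranslate2.StorageView.from_array(
        np.ascontiguousarray(np.expand_dims(mel, 0))
    )

    # With several GPUs, to_cpu=True keeps the output host-side so that the
    # next job can run on any replica.
    to_cpu = model.device == "cuda" and len(model.device_index) > 1
    encoder_output = model.encode(features, to_cpu=to_cpu)

    results = model.detect_language(encoder_output)
    print(results[0][0])  # e.g. ("<|en|>", 0.99)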
@@ -661,37 +674,6 @@ class WhisperModel:
         ]
 
 
-class WhisperEncoder:
-    """Helper class to cache and reuse the encoder output."""
-
-    def __init__(self, model: ctranslate2.models.Whisper):
-        self.model = model
-
-        # When the model is running on multiple GPUs, the encoder output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
-        self.cache_on_cpu = len(model.device_index) > 1
-
-        self.last_seek = -1
-        self.last_output = None
-
-    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
-        if self.last_seek == seek:
-            return self.last_output
-
-        features = np.expand_dims(features, 0)
-        features = get_ctranslate2_storage(features)
-
-        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
-
-        if self.last_output is not None:
-            del self.last_output
-
-        self.last_seek = seek
-        self.last_output = output
-
-        return output
-
-
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
     segment = ctranslate2.StorageView.from_array(segment)
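get_ctranslate2_storage itself is unchanged: StorageView.from_array wraps the NumPy buffer without copying (it relies on the array interface), which is why the ascontiguousarray call comes first. A small usage sketch:

    import ctranslate2
    import numpy as np

    features = np.random.rand(1, 80, 3000).astype(np.float32)

    # from_array shares the buffer rather than copying, so the layout must be
    # contiguous; ascontiguousarray is a no-op when it already is.
    storage = ctranslate2.StorageView.from_array(np.ascontiguousarray(features))
    print(storage.shape, storage.dtype)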