Run the encoder only once for each 30-second window (#73)
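Before this change, detect_language, generate, and align each received mel features and ran the encoder internally, so the same 30-second window could be encoded several times: once for language detection, once per decoding attempt during temperature fallback, and once more for word-timestamp alignment. The diff below adds a small WhisperEncoder helper that encodes each window a single time, caches the resulting ctranslate2.StorageView keyed by the window's seek offset, and passes that cached output to all three calls.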
@@ -196,14 +196,16 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate

         features = self.feature_extractor(audio)

+        whisper_encoder = WhisperEncoder(self.model)
+
         if language is None:
             if not self.model.is_multilingual:
                 language = "en"
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                input = get_ctranslate2_storage(segment)
-                results = self.model.detect_language(input)
+                encoder_output = whisper_encoder(0, segment)
+                results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
         else:
@@ -239,7 +241,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )

-        segments = self.generate_segments(features, tokenizer, options)
+        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)

        audio_info = AudioInfo(
            language=language,
@@ -252,6 +254,7 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
+        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Iterable[Segment]:
@@ -281,8 +284,10 @@ class WhisperModel:
                 prefix=options.prefix,
             )

+            encoder_output = whisper_encoder(seek, segment)
+
             result, avg_log_prob, temperature = self.generate_with_fallback(
-                segment, prompt, tokenizer, options
+                encoder_output, prompt, tokenizer, options
             )

             if options.no_speech_threshold is not None:
@@ -388,7 +393,7 @@ class WhisperModel:
                 self.add_word_timestamps(
                     current_segments,
                     tokenizer,
-                    segment,
+                    encoder_output,
                     segment_size,
                     options.prepend_punctuations,
                     options.append_punctuations,
@@ -428,12 +433,11 @@ class WhisperModel:

     def generate_with_fallback(
         self,
-        segment: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
-        features = get_ctranslate2_storage(segment)
         result = None
         avg_log_prob = None
         final_temperature = None
@@ -458,7 +462,7 @@ class WhisperModel:

             final_temperature = temperature
             result = self.model.generate(
-                features,
+                encoder_output,
                 [prompt],
                 length_penalty=options.length_penalty,
                 max_length=self.max_length,
@@ -529,7 +533,7 @@ class WhisperModel:
         self,
         segments: List[dict],
         tokenizer: Tokenizer,
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
@@ -543,7 +547,9 @@ class WhisperModel:
         ]

         text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
-        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        alignment = self.find_alignment(
+            tokenizer, text_tokens, encoder_output, num_frames
+        )
         merge_punctuations(alignment, prepend_punctuations, append_punctuations)

         time_offset = (
@@ -585,7 +591,7 @@ class WhisperModel:
         self,
         tokenizer: Tokenizer,
         text_tokens: List[int],
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         median_filter_width: int = 7,
     ) -> List[dict]:
@@ -593,7 +599,7 @@ class WhisperModel:
             return []

         result = self.model.align(
-            get_ctranslate2_storage(mel),
+            encoder_output,
             tokenizer.sot_sequence,
             [text_tokens],
             num_frames,
@@ -646,9 +652,39 @@ class WhisperModel:
         ]


+class WhisperEncoder:
+    """Helper class to cache and reuse the encoder output."""
+
+    def __init__(self, model: ctranslate2.models.Whisper):
+        self.model = model
+
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        self.cache_on_cpu = len(model.device_index) > 1
+
+        self.last_seek = -1
+        self.last_output = None
+
+    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
+        if self.last_seek == seek:
+            return self.last_output
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
+
+        if self.last_output is not None:
+            del self.last_output
+
+        self.last_seek = seek
+        self.last_output = output
+
+        return output
+
+
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
-    segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment
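A minimal usage sketch of the new cache, not part of the commit: the model path "whisper-tiny-ct2" is a hypothetical CTranslate2-converted Whisper directory, and the 80 x 3000 feature shape is assumed to match one 30-second mel window from the feature extractor.

import ctranslate2
import numpy as np

# Hypothetical path to a CTranslate2-converted Whisper model.
model = ctranslate2.models.Whisper("whisper-tiny-ct2")
encoder = WhisperEncoder(model)

# One 30-second window of mel features (80 mel bins x 3000 frames).
segment = np.zeros((80, 3000), dtype=np.float32)

output = encoder(0, segment)     # seek=0: runs model.encode()
cached = encoder(0, segment)     # same seek: returns the cached StorageView
assert cached is output

output = encoder(3000, segment)  # new seek: encodes again and replaces the cache

Caching only the most recent window is enough because generate_segments consumes windows sequentially, and the cache_on_cpu flag moves the cached output to the CPU when the model spans multiple GPUs, since (as the in-code comment notes) the next call may be scheduled on a different device.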