Run the encoder only once for each 30-second window (#73)
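In short: `WhisperModel` now encodes each 30-second feature window once and reuses that encoder output for language detection, decoding with temperature fallback, and word-timestamp alignment, instead of pushing the mel features through the encoder again at every step. The cache is a one-slot memoization keyed on the seek offset. A minimal, self-contained sketch of the pattern (toy names only, no CTranslate2 dependency):

import numpy as np

class OneSlotCache:
    """Toy version of the WhisperEncoder cache in the diff below: keep the
    output of the last expensive call and reuse it while `seek` is unchanged."""

    def __init__(self, encode_fn):
        self.encode_fn = encode_fn  # stand-in for the expensive encoder call
        self.last_seek = -1
        self.last_output = None

    def __call__(self, seek, features):
        if self.last_seek == seek:
            return self.last_output  # same window: skip the re-encode
        self.last_seek = seek
        self.last_output = self.encode_fn(features)
        return self.last_output

n_calls = 0

def fake_encode(features):
    global n_calls
    n_calls += 1
    return features * 2

encoder = OneSlotCache(fake_encode)
window = np.ones(3, dtype=np.float32)
encoder(0, window)  # encodes
encoder(0, window)  # cache hit: language detection, decoding and
encoder(0, window)  # alignment for one window all reuse this output
encoder(1, window)  # seek advanced: encode the next window
assert n_calls == 2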
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -196,14 +196,16 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
         features = self.feature_extractor(audio)
 
+        whisper_encoder = WhisperEncoder(self.model)
+
         if language is None:
             if not self.model.is_multilingual:
                 language = "en"
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                input = get_ctranslate2_storage(segment)
-                results = self.model.detect_language(input)
+                encoder_output = whisper_encoder(0, segment)
+                results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
         else:
@@ -239,7 +241,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )
 
-        segments = self.generate_segments(features, tokenizer, options)
+        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
 
         audio_info = AudioInfo(
             language=language,
@@ -252,6 +254,7 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
+        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Iterable[Segment]:
@@ -281,8 +284,10 @@ class WhisperModel:
                 prefix=options.prefix,
             )
 
+            encoder_output = whisper_encoder(seek, segment)
+
             result, avg_log_prob, temperature = self.generate_with_fallback(
-                segment, prompt, tokenizer, options
+                encoder_output, prompt, tokenizer, options
             )
 
             if options.no_speech_threshold is not None:
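The cache key is the seek offset: the encoder runs once at the top of each loop iteration, and everything that happens before `seek` advances, including the fallback retries inside `generate_with_fallback` and the word-timestamp pass below, hits the cached output for the current window.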
@@ -388,7 +393,7 @@ class WhisperModel:
                 self.add_word_timestamps(
                     current_segments,
                     tokenizer,
-                    segment,
+                    encoder_output,
                     segment_size,
                     options.prepend_punctuations,
                     options.append_punctuations,
@@ -428,12 +433,11 @@ class WhisperModel:
 
     def generate_with_fallback(
         self,
-        segment: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
-        features = get_ctranslate2_storage(segment)
        result = None
        avg_log_prob = None
        final_temperature = None
@@ -458,7 +462,7 @@ class WhisperModel:
 
             final_temperature = temperature
             result = self.model.generate(
-                features,
+                encoder_output,
                 [prompt],
                 length_penalty=options.length_penalty,
                 max_length=self.max_length,
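Because `generate_with_fallback` now receives the precomputed `encoder_output`, a retry at a higher temperature only re-runs the decoder; the encoder pass for the window is never repeated.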
@@ -529,7 +533,7 @@ class WhisperModel:
         self,
         segments: List[dict],
         tokenizer: Tokenizer,
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
@@ -543,7 +547,9 @@ class WhisperModel:
         ]
 
         text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
-        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        alignment = self.find_alignment(
+            tokenizer, text_tokens, encoder_output, num_frames
+        )
         merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
         time_offset = (
@@ -585,7 +591,7 @@ class WhisperModel:
         self,
         tokenizer: Tokenizer,
         text_tokens: List[int],
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         median_filter_width: int = 7,
     ) -> List[dict]:
@@ -593,7 +599,7 @@ class WhisperModel:
             return []
 
         result = self.model.align(
-            get_ctranslate2_storage(mel),
+            encoder_output,
             tokenizer.sot_sequence,
             [text_tokens],
             num_frames,
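This is the largest saving when `word_timestamps=True`: `model.align` previously received the raw mel segment, which meant another encoder forward pass inside CTranslate2 for a window that had already been encoded for decoding.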
@@ -646,9 +652,39 @@ class WhisperModel:
         ]
 
 
+class WhisperEncoder:
+    """Helper class to cache and reuse the encoder output."""
+
+    def __init__(self, model: ctranslate2.models.Whisper):
+        self.model = model
+
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        self.cache_on_cpu = len(model.device_index) > 1
+
+        self.last_seek = -1
+        self.last_output = None
+
+    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
+        if self.last_seek == seek:
+            return self.last_output
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
+
+        if self.last_output is not None:
+            del self.last_output
+
+        self.last_seek = seek
+        self.last_output = output
+
+        return output
+
+
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
-    segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment
 
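Taken together, `WhisperEncoder` is a one-slot cache around `ctranslate2.models.Whisper.encode`. Keeping only the last output (and explicitly deleting the previous one) bounds memory to a single encoder state, and moving the output to the CPU covers the multi-GPU case where the next job may run on a different device. A hedged usage sketch; `MODEL_DIR` stands for a converted CTranslate2 Whisper model directory and is not part of this commit:

import ctranslate2
import numpy as np

model = ctranslate2.models.Whisper("MODEL_DIR")  # hypothetical model path
encoder = WhisperEncoder(model)

# One padded 30-second window: 80 mel bins x 3000 frames.
segment = np.zeros((80, 3000), dtype=np.float32)

output = encoder(0, segment)            # runs model.encode()
assert encoder(0, segment) is output    # cache hit: same StorageView back

results = model.detect_language(output)  # reuses the cached encoder pass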
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 av==10.*
-ctranslate2>=3.9,<4
+ctranslate2>=3.10,<4
 tokenizers==0.13.*
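The `ctranslate2` floor moves from 3.9 to 3.10 because the code above depends on `Whisper.encode(..., to_cpu=...)` and on `generate`, `detect_language`, and `align` accepting a precomputed encoder output.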