Pad the audio instead of the spectrogram

See 919a713499

Author: Guillaume Klein
Date:   2023-03-08 10:50:46 +01:00
Parent: 2646906596
Commit: 6b16b8a69c

2 changed files with 18 additions and 25 deletions
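The gist of the change: instead of padding or trimming each spectrogram window on the fly (the removed get_segment below), the waveform itself gains 30 seconds of trailing silence up front, so every window sliced from the spectrogram is already full width. A minimal sketch of that idea using Whisper's usual constants (16 kHz audio, hop length 160, 3000 frames per 30-second chunk); the names mirror the diff but the numbers are illustrative:

    import numpy as np

    sampling_rate = 16000                      # Whisper models expect 16 kHz audio
    n_samples = 30 * sampling_rate             # samples in one 30-second chunk
    hop_length = 160                           # STFT hop size
    nb_max_frames = n_samples // hop_length    # 3000 spectrogram frames per chunk

    waveform = np.zeros(7 * sampling_rate)         # say, 7 seconds of audio
    waveform = np.pad(waveform, [(0, n_samples)])  # append 30 s of silence
    assert len(waveform) == 37 * sampling_rate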


@@ -142,11 +142,14 @@ class FeatureExtractor:
             data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
         return data.T
 
-    def __call__(self, waveform):
+    def __call__(self, waveform, padding=True):
         """
         Compute the log-Mel spectrogram of the provided audio, gives similar results
         whisper's original torch implementation with 1e-5 tolerance.
         """
+        if padding:
+            waveform = np.pad(waveform, [(0, self.n_samples)])
+
         window = np.hanning(self.n_fft + 1)[:-1]
         frames = self.fram_wave(waveform)
 
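In use, the default call now returns a spectrogram carrying roughly nb_max_frames of trailing padding frames, while padding=False restores the old unpadded behavior. A hypothetical usage sketch, assuming fe is a FeatureExtractor built with its defaults:

    features = fe(audio)               # padding=True: ~fe.nb_max_frames extra frames
    raw = fe(audio, padding=False)     # spectrogram of the audio alone
    # features.shape[-1] is about raw.shape[-1] + fe.nb_max_frames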


@@ -181,7 +181,7 @@ class WhisperModel:
             language = "en"
             language_probability = 1
         else:
-            segment = self.get_segment(features)
+            segment = features[:, : self.feature_extractor.nb_max_frames]
             input = self.get_input(segment)
             results = self.model.detect_language(input)
             language_token, language_probability = results[0][0]
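Language detection previously went through get_segment to guarantee a full 30-second window; with the padding applied up front, a plain slice is enough, since the feature matrix always has at least nb_max_frames columns. Standalone, the slice amounts to (illustrative names):

    segment = features[:, :nb_max_frames]   # first 3000 frames, i.e. the first 30 s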
@@ -237,7 +237,7 @@ class WhisperModel:
         )
 
     def generate_tokenized_segments(self, features, options):
-        num_frames = features.shape[-1]
+        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         offset = 0
         all_tokens = []
         prompt_reset_since = 0
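content_frames subtracts the known amount of trailing padding back out of the feature length, leaving the number of frames that correspond to real audio. A worked example with made-up numbers (hop 160 at 16 kHz gives 100 frames per second):

    total_frames = 3700                    # 7 s of audio (700 frames) + 3000 padding frames
    content_frames = total_frames - 3000   # 700 frames of actual content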
@@ -247,10 +247,15 @@ class WhisperModel:
             initial_prompt_tokens = self.encode_text(initial_prompt)
             all_tokens.extend(initial_prompt_tokens)
 
-        while offset < num_frames:
+        while offset < content_frames:
             time_offset = offset * self.feature_extractor.time_per_frame
-            segment = self.get_segment(features, offset)
-            segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
+            segment = features[
+                :, offset : offset + self.feature_extractor.nb_max_frames
+            ]
+            segment_size = min(
+                self.feature_extractor.nb_max_frames, content_frames - offset
+            )
+            segment_duration = segment_size * self.feature_extractor.time_per_frame
 
             previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
@@ -278,7 +283,7 @@ class WhisperModel:
 
                 if should_skip:
                     # fast-forward to the next segment boundary
-                    offset += segment.shape[-1]
+                    offset += segment_size
                     continue
 
             tokens = result.sequences_ids[0]
@@ -320,7 +325,7 @@ class WhisperModel:
 
                 if ended_with_single_timestamp:
                     # single timestamp at the end means no speech after the last timestamp.
-                    offset += segment.shape[-1]
+                    offset += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_position = (
@@ -341,7 +346,7 @@ class WhisperModel:
 
                 yield time_offset, time_offset + duration, tokens
 
-            offset += segment.shape[-1]
+            offset += segment_size
             all_tokens.extend(tokens)
 
             if not options.condition_on_previous_text or temperature > 0.5:
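Taken together, the loop changes work like this: the slice is always nb_max_frames wide because the padding guarantees enough columns, while segment_size measures how much of the slice is real content and drives both the reported duration and the offset advance. A self-contained sketch under those assumptions, with hypothetical sizes:

    import numpy as np

    nb_max_frames = 3000                                # frames in a 30 s window
    features = np.zeros((80, 700 + nb_max_frames))      # 700 content frames + padding
    content_frames = features.shape[-1] - nb_max_frames

    offset = 0
    while offset < content_frames:
        segment = features[:, offset : offset + nb_max_frames]  # always 3000 wide
        segment_size = min(nb_max_frames, content_frames - offset)
        # ... decode `segment` here, then skip past the consumed frames:
        offset += segment_size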
@@ -456,23 +461,8 @@ class WhisperModel:
         return prompt
 
-    def get_segment(self, features, offset=0):
-        if offset > 0:
-            features = features[:, offset:]
-
-        num_frames = features.shape[-1]
-        required_num_frames = self.feature_extractor.nb_max_frames
-
-        if num_frames > required_num_frames:
-            features = features[:, :required_num_frames]
-        elif num_frames < required_num_frames:
-            pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
-            features = np.pad(features, pad_widths)
-
-        features = np.ascontiguousarray(features)
-        return features
-
     def get_input(self, segment):
+        segment = np.ascontiguousarray(segment)
         segment = np.expand_dims(segment, 0)
         segment = ctranslate2.StorageView.from_array(segment)
         return segment
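Finally, the np.ascontiguousarray call migrates from get_segment into get_input: the windows are now plain slices, i.e. non-contiguous views into the feature matrix, and ctranslate2.StorageView.from_array needs a C-contiguous buffer as far as I know, so the single copy happens here just before the batch dimension is added. A sketch with made-up indices:

    window = features[:, 100:3100]          # a column slice: a view, not C-contiguous
    window = np.ascontiguousarray(window)   # one explicit copy before from_array
    batch = np.expand_dims(window, 0)       # add the batch dimension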