diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py
index 6260009..a525d53 100644
--- a/faster_whisper/feature_extractor.py
+++ b/faster_whisper/feature_extractor.py
@@ -142,11 +142,14 @@ class FeatureExtractor:
             data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
         return data.T
 
-    def __call__(self, waveform):
+    def __call__(self, waveform, padding=True):
         """
         Compute the log-Mel spectrogram of the provided audio, gives similar results
         whisper's original torch implementation with 1e-5 tolerance.
         """
+        if padding:
+            waveform = np.pad(waveform, [(0, self.n_samples)])
+
         window = np.hanning(self.n_fft + 1)[:-1]
 
         frames = self.fram_wave(waveform)
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 032bebd..ebc1272 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -181,7 +181,7 @@ class WhisperModel:
             language = "en"
             language_probability = 1
         else:
-            segment = self.get_segment(features)
+            segment = features[:, : self.feature_extractor.nb_max_frames]
             input = self.get_input(segment)
             results = self.model.detect_language(input)
             language_token, language_probability = results[0][0]
@@ -237,7 +237,7 @@ class WhisperModel:
         )
 
     def generate_tokenized_segments(self, features, options):
-        num_frames = features.shape[-1]
+        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         offset = 0
         all_tokens = []
         prompt_reset_since = 0
@@ -247,10 +247,15 @@ class WhisperModel:
             initial_prompt_tokens = self.encode_text(initial_prompt)
             all_tokens.extend(initial_prompt_tokens)
 
-        while offset < num_frames:
+        while offset < content_frames:
             time_offset = offset * self.feature_extractor.time_per_frame
-            segment = self.get_segment(features, offset)
-            segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
+            segment = features[
+                :, offset : offset + self.feature_extractor.nb_max_frames
+            ]
+            segment_size = min(
+                self.feature_extractor.nb_max_frames, content_frames - offset
+            )
+            segment_duration = segment_size * self.feature_extractor.time_per_frame
 
             previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
@@ -278,7 +283,7 @@ class WhisperModel:
 
                 if should_skip:
                     # fast-forward to the next segment boundary
-                    offset += segment.shape[-1]
+                    offset += segment_size
                     continue
 
                 tokens = result.sequences_ids[0]
@@ -320,7 +325,7 @@ class WhisperModel:
 
                 if ended_with_single_timestamp:
                     # single timestamp at the end means no speech after the last timestamp.
-                    offset += segment.shape[-1]
+                    offset += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_position = (
@@ -341,7 +346,7 @@ class WhisperModel:
 
                 yield time_offset, time_offset + duration, tokens
 
-            offset += segment.shape[-1]
+            offset += segment_size
             all_tokens.extend(tokens)
 
             if not options.condition_on_previous_text or temperature > 0.5:
@@ -456,23 +461,8 @@ class WhisperModel:
 
         return prompt
 
-    def get_segment(self, features, offset=0):
-        if offset > 0:
-            features = features[:, offset:]
-
-        num_frames = features.shape[-1]
-        required_num_frames = self.feature_extractor.nb_max_frames
-
-        if num_frames > required_num_frames:
-            features = features[:, :required_num_frames]
-        elif num_frames < required_num_frames:
-            pad_widths = [(0, 0), (0, required_num_frames - num_frames)]
-            features = np.pad(features, pad_widths)
-
-        features = np.ascontiguousarray(features)
-        return features
-
     def get_input(self, segment):
+        segment = np.ascontiguousarray(segment)
         segment = np.expand_dims(segment, 0)
         segment = ctranslate2.StorageView.from_array(segment)
         return segment
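
Note on the change (not part of the patch): computing the log-Mel features once for the whole audio, with one extra window of zero padding appended by __call__, lets the transcription loop take plain slices of the feature matrix instead of calling the removed get_segment, which re-padded every window. Below is a minimal, self-contained sketch of that slicing logic; the concrete numbers (80 mel bins, 3000 frames per 30-second window, 10 ms per frame) are assumed Whisper defaults, and the arrays are toy stand-ins rather than the library's API.

import numpy as np

n_mels = 80               # assumed number of mel bins
nb_max_frames = 3000      # assumed frames per 30 s window
time_per_frame = 0.01     # assumed 10 ms hop between frames

# Stand-in for the output of the patched __call__ with padding=True:
# 45 s of real audio frames plus one full window of trailing zeros.
content = np.zeros((n_mels, 4500), dtype=np.float32)
features = np.pad(content, [(0, 0), (0, nb_max_frames)])

content_frames = features.shape[-1] - nb_max_frames  # frames of real audio
offset = 0
while offset < content_frames:
    # The trailing padding guarantees every slice is a full window,
    # so no per-segment np.pad or shape check is needed.
    segment = features[:, offset : offset + nb_max_frames]
    # True (unpadded) length of the window, used for time bookkeeping.
    segment_size = min(nb_max_frames, content_frames - offset)
    print(offset * time_per_frame, segment.shape, segment_size * time_per_frame)
    offset += segment_size

Running the sketch prints two windows: 0-30 s with segment_size == 3000, then a final window starting at 30 s whose segment_size is 1500 even though the slice still holds 3000 frames; that is exactly the distinction the patch draws between segment (the model input) and segment_size (the amount of real audio consumed).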