diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index 3190619..a597fd8 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -102,3 +102,18 @@ def _resample_frames(frames, resampler):
     # Add None to flush the resampler.
     for frame in itertools.chain(frames, [None]):
         yield from resampler.resample(frame)
+
+
+def pad_or_trim(array, length: int, *, axis: int = -1):
+    """
+    Pad or trim the audio array to *length* along *axis*, as expected by the encoder.
+    """
+    if array.shape[axis] > length:
+        array = array.take(indices=range(length), axis=axis)
+
+    if array.shape[axis] < length:
+        pad_widths = [(0, 0)] * array.ndim
+        pad_widths[axis] = (0, length - array.shape[axis])
+        array = np.pad(array, pad_widths)
+
+    return array
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index c1ea390..bce84d2 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -11,7 +11,7 @@ import ctranslate2
 import numpy as np
 import tokenizers
 
-from faster_whisper.audio import decode_audio
+from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
 from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
@@ -492,6 +492,7 @@ class WhisperModel:
             )
             segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame
+            segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
 
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug(