Add pad_or_trim function to handle segment before encoding (#705)
This commit is contained in:
@@ -102,3 +102,18 @@ def _resample_frames(frames, resampler):
|
|||||||
# Add None to flush the resampler.
|
# Add None to flush the resampler.
|
||||||
for frame in itertools.chain(frames, [None]):
|
for frame in itertools.chain(frames, [None]):
|
||||||
yield from resampler.resample(frame)
|
yield from resampler.resample(frame)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_or_trim(array, length: int, *, axis: int = -1):
|
||||||
|
"""
|
||||||
|
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||||
|
"""
|
||||||
|
if array.shape[axis] > length:
|
||||||
|
array = array.take(indices=range(length), axis=axis)
|
||||||
|
|
||||||
|
if array.shape[axis] < length:
|
||||||
|
pad_widths = [(0, 0)] * array.ndim
|
||||||
|
pad_widths[axis] = (0, length - array.shape[axis])
|
||||||
|
array = np.pad(array, pad_widths)
|
||||||
|
|
||||||
|
return array
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import ctranslate2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tokenizers
|
import tokenizers
|
||||||
|
|
||||||
from faster_whisper.audio import decode_audio
|
from faster_whisper.audio import decode_audio, pad_or_trim
|
||||||
from faster_whisper.feature_extractor import FeatureExtractor
|
from faster_whisper.feature_extractor import FeatureExtractor
|
||||||
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
|
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
|
||||||
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
|
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
|
||||||
@@ -492,6 +492,7 @@ class WhisperModel:
|
|||||||
)
|
)
|
||||||
segment = features[:, seek : seek + segment_size]
|
segment = features[:, seek : seek + segment_size]
|
||||||
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
||||||
|
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
|
||||||
|
|
||||||
if self.logger.isEnabledFor(logging.DEBUG):
|
if self.logger.isEnabledFor(logging.DEBUG):
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
|
|||||||
Reference in New Issue
Block a user