diff --git a/whisper/audio.py b/whisper/audio.py index a6074e8..a3d8a13 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): """ if torch.is_tensor(array): if array.shape[axis] > length: - array = array.index_select(dim=axis, index=torch.arange(length)) + array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) if array.shape[axis] < length: pad_widths = [(0, 0)] * array.ndim