Fix bug (#305)
Fix bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
This commit is contained in:
@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
|||||||
"""
|
"""
|
||||||
if torch.is_tensor(array):
|
if torch.is_tensor(array):
|
||||||
if array.shape[axis] > length:
|
if array.shape[axis] > length:
|
||||||
array = array.index_select(dim=axis, index=torch.arange(length))
|
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
||||||
|
|
||||||
if array.shape[axis] < length:
|
if array.shape[axis] < length:
|
||||||
pad_widths = [(0, 0)] * array.ndim
|
pad_widths = [(0, 0)] * array.ndim
|
||||||
|
|||||||
Reference in New Issue
Block a user