Revert "Merge remote-tracking branch 'upstream/master' into prompt"
This reverts commit 6e42088656, reversing changes made to 4a59bb011d.
This commit is contained in:
@@ -1,7 +1,19 @@
|
||||
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
|
||||
|
||||
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
|
||||
system dependencies. FFmpeg does not need to be installed on the system.
|
||||
|
||||
However, the API is quite low-level so we need to manipulate audio frames directly.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import io
|
||||
import itertools
|
||||
|
||||
from typing import BinaryIO, Union
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
|
||||
def decode_audio(
|
||||
@@ -17,42 +29,91 @@ def decode_audio(
|
||||
split_stereo: Return separate left and right channels.
|
||||
|
||||
Returns:
|
||||
A float32 Torch Tensor.
|
||||
A float32 Numpy array.
|
||||
|
||||
If `split_stereo` is enabled, the function returns a 2-tuple with the
|
||||
separated left and right channels.
|
||||
"""
|
||||
resampler = av.audio.resampler.AudioResampler(
|
||||
format="s16",
|
||||
layout="mono" if not split_stereo else "stereo",
|
||||
rate=sampling_rate,
|
||||
)
|
||||
|
||||
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T
|
||||
raw_buffer = io.BytesIO()
|
||||
dtype = None
|
||||
|
||||
with av.open(input_file, mode="r", metadata_errors="ignore") as container:
|
||||
frames = container.decode(audio=0)
|
||||
frames = _ignore_invalid_frames(frames)
|
||||
frames = _group_frames(frames, 500000)
|
||||
frames = _resample_frames(frames, resampler)
|
||||
|
||||
for frame in frames:
|
||||
array = frame.to_ndarray()
|
||||
dtype = array.dtype
|
||||
raw_buffer.write(array)
|
||||
|
||||
# It appears that some objects related to the resampler are not freed
|
||||
# unless the garbage collector is manually run.
|
||||
del resampler
|
||||
gc.collect()
|
||||
|
||||
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
|
||||
|
||||
# Convert s16 back to f32.
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
|
||||
if audio_sf != sampling_rate:
|
||||
waveform = torchaudio.functional.resample(
|
||||
waveform, orig_freq=audio_sf, new_freq=sampling_rate
|
||||
)
|
||||
if split_stereo:
|
||||
return waveform[0], waveform[1]
|
||||
left_channel = audio[0::2]
|
||||
right_channel = audio[1::2]
|
||||
return left_channel, right_channel
|
||||
|
||||
return waveform.mean(0)
|
||||
return audio
|
||||
|
||||
|
||||
def _ignore_invalid_frames(frames):
    """Yield the items of ``frames``, skipping any that fail to decode.

    Each item is pulled from the underlying iterator one at a time. When
    pulling an item raises ``av.error.InvalidDataError`` the error is
    swallowed and iteration moves on to the next item; the generator ends
    normally once the underlying iterator is exhausted.
    """
    source = iter(frames)
    exhausted = False

    while not exhausted:
        try:
            yield next(source)
        except StopIteration:
            exhausted = True
        except av.error.InvalidDataError:
            # Drop the frame that could not be decoded and keep going.
            pass
|
||||
|
||||
|
||||
def _group_frames(frames, num_samples=None):
    """Regroup incoming audio frames into larger frames via an ``AudioFifo``.

    Every incoming frame is written to the FIFO. When ``num_samples`` is
    given, the FIFO content is read out and yielded each time it holds at
    least that many samples. Whatever is still buffered after the input is
    exhausted is yielded as one final frame. With ``num_samples=None`` only
    that final frame is produced.
    """
    buffer = av.audio.fifo.AudioFifo()
    has_threshold = num_samples is not None

    for frame in frames:
        # Clear the timestamp so the FIFO does not run its timestamp check.
        frame.pts = None
        buffer.write(frame)

        if has_threshold and buffer.samples >= num_samples:
            yield buffer.read()

    # Emit the remainder that never reached the threshold.
    if buffer.samples > 0:
        yield buffer.read()
|
||||
|
||||
|
||||
def _resample_frames(frames, resampler):
    """Pass every frame through ``resampler`` and yield the resampled frames.

    ``resampler.resample`` is called once per input frame and everything it
    returns is yielded in order. After the input is exhausted it is called
    one more time with ``None`` to flush the resampler.
    """
    for frame in frames:
        yield from resampler.resample(frame)

    # A final call with None flushes the resampler.
    yield from resampler.resample(None)
|
||||
|
||||
|
||||
def pad_or_trim(array, length: int, *, axis: int = -1):
    """Pad or trim ``array`` so that it has exactly ``length`` entries on ``axis``.

    Used to bring an audio array to the N_SAMPLES size expected by the encoder.

    Args:
      array: NumPy array to adjust.
      length: Target size along ``axis``.
      axis: Axis to pad or trim (keyword-only). Negative values count from
        the last axis. Defaults to the last axis.

    Returns:
      ``array`` itself when it already has the requested size; otherwise a
      new array that keeps the first ``length`` entries along ``axis`` (when
      trimming) or is extended with trailing zeros along ``axis`` (when
      padding). All other axes are left untouched.
    """
    current = array.shape[axis]

    if current > length:
        # Keep only the leading `length` entries along the chosen axis.
        return array.take(indices=range(length), axis=axis)

    if current < length:
        # Pad only at the end of the chosen axis; every other axis gets (0, 0).
        pad_widths = [(0, 0)] * array.ndim
        pad_widths[axis] = (0, length - current)
        return np.pad(array, pad_widths)

    return array
|
||||
|
||||
Reference in New Issue
Block a user