Revert "Merge remote-tracking branch 'upstream/master' into prompt"

This reverts commit 6e42088656, reversing
changes made to 4a59bb011d.
This commit is contained in:
2024-09-12 00:49:31 +08:00
parent 6e42088656
commit 28a4d11a73
13 changed files with 423 additions and 1599 deletions

View File

@@ -1,7 +1,19 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""
import gc
import io
import itertools
from typing import BinaryIO, Union
import torch
import torchaudio
import av
import numpy as np
def decode_audio(
@@ -17,42 +29,91 @@ def decode_audio(
split_stereo: Return separate left and right channels.
Returns:
A float32 Torch Tensor.
A float32 Numpy array.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T
raw_buffer = io.BytesIO()
dtype = None
with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)
for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)
# It appears that some objects related to the resampler are not freed
# unless the garbage collector is manually run.
del resampler
gc.collect()
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0
if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
if split_stereo:
return waveform[0], waveform[1]
left_channel = audio[0::2]
right_channel = audio[1::2]
return left_channel, right_channel
return waveform.mean(0)
return audio
def _ignore_invalid_frames(frames):
iterator = iter(frames)
while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue
def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()
for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)
if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()
if fifo.samples > 0:
yield fifo.read()
def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array