diff --git a/demucs/__pycache__/apply.cpython-39.pyc b/demucs/__pycache__/apply.cpython-39.pyc deleted file mode 100644 index cbd1c41..0000000 Binary files a/demucs/__pycache__/apply.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/audio.cpython-39.pyc b/demucs/__pycache__/audio.cpython-39.pyc deleted file mode 100644 index d257ebb..0000000 Binary files a/demucs/__pycache__/audio.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/demucs.cpython-39.pyc b/demucs/__pycache__/demucs.cpython-39.pyc deleted file mode 100644 index f608b5a..0000000 Binary files a/demucs/__pycache__/demucs.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/hdemucs.cpython-39.pyc b/demucs/__pycache__/hdemucs.cpython-39.pyc deleted file mode 100644 index c9d8ff4..0000000 Binary files a/demucs/__pycache__/hdemucs.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/pretrained.cpython-39.pyc b/demucs/__pycache__/pretrained.cpython-39.pyc deleted file mode 100644 index d3cc165..0000000 Binary files a/demucs/__pycache__/pretrained.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/repo.cpython-39.pyc b/demucs/__pycache__/repo.cpython-39.pyc deleted file mode 100644 index 18ef7a3..0000000 Binary files a/demucs/__pycache__/repo.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/spec.cpython-39.pyc b/demucs/__pycache__/spec.cpython-39.pyc deleted file mode 100644 index 1c75010..0000000 Binary files a/demucs/__pycache__/spec.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/states.cpython-39.pyc b/demucs/__pycache__/states.cpython-39.pyc deleted file mode 100644 index d013690..0000000 Binary files a/demucs/__pycache__/states.cpython-39.pyc and /dev/null differ diff --git a/demucs/__pycache__/utils.cpython-39.pyc b/demucs/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 0ad9279..0000000 Binary files a/demucs/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/demucs/apply.py b/demucs/apply.py deleted file mode 100644 index 10ebd66..0000000 --- a/demucs/apply.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Code to apply a model to a mix. It will handle chunking with overlaps and -inteprolation between chunks, as well as the "shift trick". -""" -from concurrent.futures import ThreadPoolExecutor -import random -import typing as tp - -import torch as th -from torch import nn -from torch.nn import functional as F -import tqdm - -from .demucs import Demucs -from .hdemucs import HDemucs -from .utils import center_trim, DummyPoolExecutor - -Model = tp.Union[Demucs, HDemucs] - - -class BagOfModels(nn.Module): - def __init__(self, models: tp.List[Model], - weights: tp.Optional[tp.List[tp.List[float]]] = None, - segment: tp.Optional[float] = None): - """ - Represents a bag of models with specific weights. - You should call `apply_model` rather than calling directly the forward here for - optimal performance. - - Args: - models (list[nn.Module]): list of Demucs/HDemucs models. - weights (list[list[float]]): list of weights. If None, assumed to - be all ones, otherwise it should be a list of N list (N number of models), - each containing S floats (S number of sources). - segment (None or float): overrides the `segment` attribute of each model - (this is performed inplace, be careful is you reuse the models passed). - """ - super().__init__() - assert len(models) > 0 - first = models[0] - for other in models: - assert other.sources == first.sources - assert other.samplerate == first.samplerate - assert other.audio_channels == first.audio_channels - if segment is not None: - other.segment = segment - - self.audio_channels = first.audio_channels - self.samplerate = first.samplerate - self.sources = first.sources - self.models = nn.ModuleList(models) - - if weights is None: - weights = [[1. for _ in first.sources] for _ in models] - else: - assert len(weights) == len(models) - for weight in weights: - assert len(weight) == len(first.sources) - self.weights = weights - - def forward(self, x): - raise NotImplementedError("Call `apply_model` on this.") - - -class TensorChunk: - def __init__(self, tensor, offset=0, length=None): - total_length = tensor.shape[-1] - assert offset >= 0 - assert offset < total_length - - if length is None: - length = total_length - offset - else: - length = min(total_length - offset, length) - - self.tensor = tensor - self.offset = offset - self.length = length - self.device = tensor.device - - @property - def shape(self): - shape = list(self.tensor.shape) - shape[-1] = self.length - return shape - - def padded(self, target_length): - delta = target_length - self.length - total_length = self.tensor.shape[-1] - assert delta >= 0 - - start = self.offset - delta // 2 - end = start + target_length - - correct_start = max(0, start) - correct_end = min(total_length, end) - - pad_left = correct_start - start - pad_right = end - correct_end - - out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) - assert out.shape[-1] == target_length - return out - - -def tensor_chunk(tensor_or_chunk): - if isinstance(tensor_or_chunk, TensorChunk): - return tensor_or_chunk - else: - assert isinstance(tensor_or_chunk, th.Tensor) - return TensorChunk(tensor_or_chunk) - - -def apply_model(model, mix, shifts=1, split=True, - overlap=0.25, transition_power=1., progress=False, device=None, - num_workers=0, pool=None): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. - progress (bool): if True, show a progress bar (requires split=True) - device (torch.device, str, or None): if provided, device on which to - execute the computation, otherwise `mix.device` is assumed. - When `device` is different from `mix.device`, only local computations will - be on `device`, while the entire tracks will be stored on `mix.device`. - """ - if device is None: - device = mix.device - else: - device = th.device(device) - if pool is None: - if num_workers > 0 and device.type == 'cpu': - pool = ThreadPoolExecutor(num_workers) - else: - pool = DummyPoolExecutor() - kwargs = { - 'shifts': shifts, - 'split': split, - 'overlap': overlap, - 'transition_power': transition_power, - 'progress': progress, - 'device': device, - 'pool': pool, - } - if isinstance(model, BagOfModels): - # Special treatment for bag of model. - # We explicitely apply multiple times `apply_model` so that the random shifts - # are different for each model. - estimates = 0 - totals = [0] * len(model.sources) - for sub_model, weight in zip(model.models, model.weights): - original_model_device = next(iter(sub_model.parameters())).device - sub_model.to(device) - - out = apply_model(sub_model, mix, **kwargs) - sub_model.to(original_model_device) - for k, inst_weight in enumerate(weight): - out[:, k, :, :] *= inst_weight - totals[k] += inst_weight - estimates += out - del out - - for k in range(estimates.shape[1]): - estimates[:, k, :, :] /= totals[k] - return estimates - - model.to(device) - assert transition_power >= 1, "transition_power < 1 leads to weird behavior." - batch, channels, length = mix.shape - if split: - kwargs['split'] = False - out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) - sum_weight = th.zeros(length, device=mix.device) - segment = int(model.samplerate * model.segment) - stride = int((1 - overlap) * segment) - offsets = range(0, length, stride) - scale = stride / model.samplerate - # We start from a triangle shaped weight, with maximal weight in the middle - # of the segment. Then we normalize and take to the power `transition_power`. - # Large values of transition power will lead to sharper transitions. - weight = th.cat([th.arange(1, segment // 2 + 1, device=device), - th.arange(segment - segment // 2, 0, -1, device=device)]) - assert len(weight) == segment - # If the overlap < 50%, this will translate to linear transition when - # transition_power is 1. - weight = (weight / weight.max())**transition_power - futures = [] - for offset in offsets: - chunk = TensorChunk(mix, offset, segment) - future = pool.submit(apply_model, model, chunk, **kwargs) - futures.append((future, offset)) - offset += segment - if progress: - futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') - for future, offset in futures: - chunk_out = future.result() - chunk_length = chunk_out.shape[-1] - out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device) - sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device) - assert sum_weight.min() > 0 - out /= sum_weight - return out - elif shifts: - kwargs['shifts'] = 0 - max_shift = int(0.5 * model.samplerate) - mix = tensor_chunk(mix) - padded_mix = mix.padded(length + 2 * max_shift) - out = 0 - for _ in range(shifts): - offset = random.randint(0, max_shift) - shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) - shifted_out = apply_model(model, shifted, **kwargs) - out += shifted_out[..., max_shift - offset:] - out /= shifts - return out - else: - if hasattr(model, 'valid_length'): - valid_length = model.valid_length(length) - else: - valid_length = length - mix = tensor_chunk(mix) - padded_mix = mix.padded(valid_length).to(device) - with th.no_grad(): - out = model(padded_mix) - return center_trim(out, length) diff --git a/demucs/audio.py b/demucs/audio.py deleted file mode 100644 index d1ba194..0000000 --- a/demucs/audio.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -import json -import subprocess as sp -from pathlib import Path - -import lameenc -import julius -import numpy as np -import torch -import torchaudio as ta - -from .utils import temp_filenames - - -def _read_info(path): - stdout_data = sp.check_output([ - 'ffprobe', "-loglevel", "panic", - str(path), '-print_format', 'json', '-show_format', '-show_streams' - ]) - return json.loads(stdout_data.decode('utf-8')) - - -class AudioFile: - """ - Allows to read audio from any format supported by ffmpeg, as well as resampling or - converting to mono on the fly. See :method:`read` for more details. - """ - def __init__(self, path: Path): - self.path = Path(path) - self._info = None - - def __repr__(self): - features = [("path", self.path)] - features.append(("samplerate", self.samplerate())) - features.append(("channels", self.channels())) - features.append(("streams", len(self))) - features_str = ", ".join(f"{name}={value}" for name, value in features) - return f"AudioFile({features_str})" - - @property - def info(self): - if self._info is None: - self._info = _read_info(self.path) - return self._info - - @property - def duration(self): - return float(self.info['format']['duration']) - - @property - def _audio_streams(self): - return [ - index for index, stream in enumerate(self.info["streams"]) - if stream["codec_type"] == "audio" - ] - - def __len__(self): - return len(self._audio_streams) - - def channels(self, stream=0): - return int(self.info['streams'][self._audio_streams[stream]]['channels']) - - def samplerate(self, stream=0): - return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) - - def read(self, - seek_time=None, - duration=None, - streams=slice(None), - samplerate=None, - channels=None, - temp_folder=None): - """ - Slightly more efficient implementation than stempeg, - in particular, this will extract all stems at once - rather than having to loop over one file multiple times - for each stream. - - Args: - seek_time (float): seek time in seconds or None if no seeking is needed. - duration (float): duration in seconds to extract or None to extract until the end. - streams (slice, int or list): streams to extract, can be a single int, a list or - a slice. If it is a slice or list, the output will be of size [S, C, T] - with S the number of streams, C the number of channels and T the number of samples. - If it is an int, the output will be [C, T]. - samplerate (int): if provided, will resample on the fly. If None, no resampling will - be done. Original sampling rate can be obtained with :method:`samplerate`. - channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that - as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. - See https://sound.stackexchange.com/a/42710. - Our definition of mono is simply the average of the two channels. Any other - value will be ignored. - temp_folder (str or Path or None): temporary folder to use for decoding. - - - """ - streams = np.array(range(len(self)))[streams] - single = not isinstance(streams, np.ndarray) - if single: - streams = [streams] - - if duration is None: - target_size = None - query_duration = None - else: - target_size = int((samplerate or self.samplerate()) * duration) - query_duration = float((target_size + 1) / (samplerate or self.samplerate())) - - with temp_filenames(len(streams)) as filenames: - command = ['ffmpeg', '-y'] - command += ['-loglevel', 'panic'] - if seek_time: - command += ['-ss', str(seek_time)] - command += ['-i', str(self.path)] - for stream, filename in zip(streams, filenames): - command += ['-map', f'0:{self._audio_streams[stream]}'] - if query_duration is not None: - command += ['-t', str(query_duration)] - command += ['-threads', '1'] - command += ['-f', 'f32le'] - if samplerate is not None: - command += ['-ar', str(samplerate)] - command += [filename] - - sp.run(command, check=True) - wavs = [] - for filename in filenames: - wav = np.fromfile(filename, dtype=np.float32) - wav = torch.from_numpy(wav) - wav = wav.view(-1, self.channels()).t() - if channels is not None: - wav = convert_audio_channels(wav, channels) - if target_size is not None: - wav = wav[..., :target_size] - wavs.append(wav) - wav = torch.stack(wavs, dim=0) - if single: - wav = wav[0] - return wav - - -def convert_audio_channels(wav, channels=2): - """Convert audio to the given number of channels.""" - *shape, src_channels, length = wav.shape - if src_channels == channels: - pass - elif channels == 1: - # Case 1: - # The caller asked 1-channel audio, but the stream have multiple - # channels, downmix all channels. - wav = wav.mean(dim=-2, keepdim=True) - elif src_channels == 1: - # Case 2: - # The caller asked for multiple channels, but the input file have - # one single channel, replicate the audio over all channels. - wav = wav.expand(*shape, channels, length) - elif src_channels >= channels: - # Case 3: - # The caller asked for multiple channels, and the input file have - # more channels than requested. In that case return the first channels. - wav = wav[..., :channels, :] - else: - # Case 4: What is a reasonable choice here? - raise ValueError('The audio file has less channels than requested but is not mono.') - return wav - - -def convert_audio(wav, from_samplerate, to_samplerate, channels): - """Convert audio from a given samplerate to a target one and target number of channels.""" - wav = convert_audio_channels(wav, channels) - return julius.resample_frac(wav, from_samplerate, to_samplerate) - - -def i16_pcm(wav): - """Convert audio to 16 bits integer PCM format.""" - if wav.dtype.is_floating_point: - return (wav.clamp_(-1, 1) * (2**15 - 1)).short() - else: - return wav - - -def f32_pcm(wav): - """Convert audio to float 32 bits PCM format.""" - if wav.dtype.is_floating_point: - return wav - else: - return wav.float() / (2**15 - 1) - - -def as_dtype_pcm(wav, dtype): - """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" - if wav.dtype.is_floating_point: - return f32_pcm(wav) - else: - return i16_pcm(wav) - - -def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False): - """Save given audio as mp3. This should work on all OSes.""" - C, T = wav.shape - wav = i16_pcm(wav) - encoder = lameenc.Encoder() - encoder.set_bit_rate(bitrate) - encoder.set_in_sample_rate(samplerate) - encoder.set_channels(C) - encoder.set_quality(2) # 2-highest, 7-fastest - if not verbose: - encoder.silence() - wav = wav.transpose(0, 1).numpy() - mp3_data = encoder.encode(wav.tobytes()) - mp3_data += encoder.flush() - with open(path, "wb") as f: - f.write(mp3_data) - - -def prevent_clip(wav, mode='rescale'): - """ - different strategies for avoiding raw clipping. - """ - assert wav.dtype.is_floating_point, "too late for clipping" - if mode == 'rescale': - wav = wav / max(1.01 * wav.abs().max(), 1) - elif mode == 'clamp': - wav = wav.clamp(-0.99, 0.99) - elif mode == 'tanh': - wav = torch.tanh(wav) - else: - raise ValueError(f"Invalid mode {mode}") - return wav - - -def save_audio(wav, path, samplerate, bitrate=320, clip='rescale', - bits_per_sample=16, as_float=False): - """Save audio file, automatically preventing clipping if necessary - based on the given `clip` strategy. If the path ends in `.mp3`, this - will save as mp3 with the given `bitrate`. - """ - wav = prevent_clip(wav, mode=clip) - path = Path(path) - suffix = path.suffix.lower() - if suffix == ".mp3": - encode_mp3(wav, path, samplerate, bitrate) - elif suffix == ".wav": - if as_float: - bits_per_sample = 32 - encoding = 'PCM_F' - else: - encoding = 'PCM_S' - ta.save(str(path), wav, sample_rate=samplerate, - encoding=encoding, bits_per_sample=bits_per_sample) - else: - raise ValueError(f"Invalid suffix for path: {suffix}") diff --git a/demucs/demucs.py b/demucs/demucs.py deleted file mode 100644 index d2c08e7..0000000 --- a/demucs/demucs.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math -import typing as tp - -import julius -import torch -from torch import nn -from torch.nn import functional as F - -from .states import capture_init -from .utils import center_trim, unfold - - -class BLSTM(nn.Module): - """ - BiLSTM with same hidden units as input dim. - If `max_steps` is not None, input will be splitting in overlapping - chunks and the LSTM applied separately on each chunk. - """ - def __init__(self, dim, layers=1, max_steps=None, skip=False): - super().__init__() - assert max_steps is None or max_steps % 4 == 0 - self.max_steps = max_steps - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - self.skip = skip - - def forward(self, x): - B, C, T = x.shape - y = x - framed = False - if self.max_steps is not None and T > self.max_steps: - width = self.max_steps - stride = width // 2 - frames = unfold(x, width, stride) - nframes = frames.shape[2] - framed = True - x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) - - x = x.permute(2, 0, 1) - - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - if framed: - out = [] - frames = x.reshape(B, -1, C, width) - limit = stride // 2 - for k in range(nframes): - if k == 0: - out.append(frames[:, k, :, :-limit]) - elif k == nframes - 1: - out.append(frames[:, k, :, limit:]) - else: - out.append(frames[:, k, :, limit:-limit]) - out = torch.cat(out, -1) - out = out[..., :T] - x = out - if self.skip: - x = x + y - return x - - -def rescale_conv(conv, reference): - """Rescale initial weight scale. It is unclear why it helps but it certainly does. - """ - std = conv.weight.std().detach() - scale = (std / reference)**0.5 - conv.weight.data /= scale - if conv.bias is not None: - conv.bias.data /= scale - - -def rescale_module(module, reference): - for sub in module.modules(): - if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): - rescale_conv(sub, reference) - - -class LayerScale(nn.Module): - """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). - This rescales diagonaly residual outputs close to 0 initially, then learnt. - """ - def __init__(self, channels: int, init: float = 0): - super().__init__() - self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) - self.scale.data[:] = init - - def forward(self, x): - return self.scale[:, None] * x - - -class DConv(nn.Module): - """ - New residual branches in each encoder layer. - This alternates dilated convolutions, potentially with LSTMs and attention. - Also before entering each residual branch, dimension is projected on a smaller subspace, - e.g. of dim `channels // compress`. - """ - def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, - norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, - kernel=3, dilate=True): - """ - Args: - channels: input/output channels for residual branch. - compress: amount of channel compression inside the branch. - depth: number of layers in the residual branch. Each layer has its own - projection, and potentially LSTM and attention. - init: initial scale for LayerNorm. - norm: use GroupNorm. - attn: use LocalAttention. - heads: number of heads for the LocalAttention. - ndecay: number of decay controls in the LocalAttention. - lstm: use LSTM. - gelu: Use GELU activation. - kernel: kernel size for the (dilated) convolutions. - dilate: if true, use dilation, increasing with the depth. - """ - - super().__init__() - assert kernel % 2 == 1 - self.channels = channels - self.compress = compress - self.depth = abs(depth) - dilate = depth > 0 - - norm_fn: tp.Callable[[int], nn.Module] - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(1, d) # noqa - - hidden = int(channels / compress) - - act: tp.Type[nn.Module] - if gelu: - act = nn.GELU - else: - act = nn.ReLU - - self.layers = nn.ModuleList([]) - for d in range(self.depth): - dilation = 2 ** d if dilate else 1 - padding = dilation * (kernel // 2) - mods = [ - nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), - norm_fn(hidden), act(), - nn.Conv1d(hidden, 2 * channels, 1), - norm_fn(2 * channels), nn.GLU(1), - LayerScale(channels, init), - ] - if attn: - mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) - if lstm: - mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) - layer = nn.Sequential(*mods) - self.layers.append(layer) - - def forward(self, x): - for layer in self.layers: - x = x + layer(x) - return x - - -class LocalState(nn.Module): - """Local state allows to have attention based only on data (no positional embedding), - but while setting a constraint on the time window (e.g. decaying penalty term). - - Also a failed experiments with trying to provide some frequency based attention. - """ - def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): - super().__init__() - assert channels % heads == 0, (channels, heads) - self.heads = heads - self.nfreqs = nfreqs - self.ndecay = ndecay - self.content = nn.Conv1d(channels, channels, 1) - self.query = nn.Conv1d(channels, channels, 1) - self.key = nn.Conv1d(channels, channels, 1) - if nfreqs: - self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) - if ndecay: - self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) - # Initialize decay close to zero (there is a sigmoid), for maximum initial window. - self.query_decay.weight.data *= 0.01 - assert self.query_decay.bias is not None # stupid type checker - self.query_decay.bias.data[:] = -2 - self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) - - def forward(self, x): - B, C, T = x.shape - heads = self.heads - indexes = torch.arange(T, device=x.device, dtype=x.dtype) - # left index are keys, right index are queries - delta = indexes[:, None] - indexes[None, :] - - queries = self.query(x).view(B, heads, -1, T) - keys = self.key(x).view(B, heads, -1, T) - # t are keys, s are queries - dots = torch.einsum("bhct,bhcs->bhts", keys, queries) - dots /= keys.shape[2]**0.5 - if self.nfreqs: - periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) - freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) - freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 - dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) - if self.ndecay: - decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) - decay_q = self.query_decay(x).view(B, heads, -1, T) - decay_q = torch.sigmoid(decay_q) / 2 - decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 - dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) - - # Kill self reference. - dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) - weights = torch.softmax(dots, dim=2) - - content = self.content(x).view(B, heads, -1, T) - result = torch.einsum("bhts,bhct->bhcs", weights, content) - if self.nfreqs: - time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) - result = torch.cat([result, time_sig], 2) - result = result.reshape(B, -1, T) - return x + self.proj(result) - - -class Demucs(nn.Module): - @capture_init - def __init__(self, - sources, - # Channels - audio_channels=2, - channels=64, - growth=2., - # Main structure - depth=6, - rewrite=True, - lstm_layers=0, - # Convolutions - kernel_size=8, - stride=4, - context=1, - # Activations - gelu=True, - glu=True, - # Normalization - norm_starts=4, - norm_groups=4, - # DConv residual branch - dconv_mode=1, - dconv_depth=2, - dconv_comp=4, - dconv_attn=4, - dconv_lstm=4, - dconv_init=1e-4, - # Pre/post processing - normalize=True, - resample=True, - # Weight init - rescale=0.1, - # Metadata - samplerate=44100, - segment=4 * 10): - """ - Args: - sources (list[str]): list of source names - audio_channels (int): stereo or mono - channels (int): first convolution channels - depth (int): number of encoder/decoder layers - growth (float): multiply (resp divide) number of channels by that - for each layer of the encoder (resp decoder) - depth (int): number of layers in the encoder and in the decoder. - rewrite (bool): add 1x1 convolution to each layer. - lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated - by default, as this is now replaced by the smaller and faster small LSTMs - in the DConv branches. - kernel_size (int): kernel size for convolutions - stride (int): stride for convolutions - context (int): kernel size of the convolution in the - decoder before the transposed convolution. If > 1, - will provide some context from neighboring time steps. - gelu: use GELU activation function. - glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. - norm_starts: layer at which group norm starts being used. - decoder layers are numbered in reverse order. - norm_groups: number of groups for group norm. - dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. - dconv_depth: depth of residual DConv branch. - dconv_comp: compression of DConv branch. - dconv_attn: adds attention layers in DConv branch starting at this layer. - dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. - dconv_init: initial scale for the DConv branch LayerScale. - normalize (bool): normalizes the input audio on the fly, and scales back - the output by the same amount. - resample (bool): upsample x2 the input and downsample /2 the output. - rescale (int): rescale initial weights of convolutions - to get their standard deviation closer to `rescale`. - samplerate (int): stored as meta information for easing - future evaluations of the model. - segment (float): duration of the chunks of audio to ideally evaluate the model on. - This is used by `demucs.apply.apply_model`. - """ - - super().__init__() - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.resample = resample - self.channels = channels - self.normalize = normalize - self.samplerate = samplerate - self.segment = segment - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - self.skip_scales = nn.ModuleList() - - if glu: - activation = nn.GLU(dim=1) - ch_scale = 2 - else: - activation = nn.ReLU() - ch_scale = 1 - if gelu: - act2 = nn.GELU - else: - act2 = nn.ReLU - - in_channels = audio_channels - padding = 0 - for index in range(depth): - norm_fn = lambda d: nn.Identity() # noqa - if index >= norm_starts: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - - encode = [] - encode += [ - nn.Conv1d(in_channels, channels, kernel_size, stride), - norm_fn(channels), - act2(), - ] - attn = index >= dconv_attn - lstm = index >= dconv_lstm - if dconv_mode & 1: - encode += [DConv(channels, depth=dconv_depth, init=dconv_init, - compress=dconv_comp, attn=attn, lstm=lstm)] - if rewrite: - encode += [ - nn.Conv1d(channels, ch_scale * channels, 1), - norm_fn(ch_scale * channels), activation] - self.encoder.append(nn.Sequential(*encode)) - - decode = [] - if index > 0: - out_channels = in_channels - else: - out_channels = len(self.sources) * audio_channels - if rewrite: - decode += [ - nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), - norm_fn(ch_scale * channels), activation] - if dconv_mode & 2: - decode += [DConv(channels, depth=dconv_depth, init=dconv_init, - compress=dconv_comp, attn=attn, lstm=lstm)] - decode += [nn.ConvTranspose1d(channels, out_channels, - kernel_size, stride, padding=padding)] - if index > 0: - decode += [norm_fn(out_channels), act2()] - self.decoder.insert(0, nn.Sequential(*decode)) - in_channels = channels - channels = int(growth * channels) - - channels = in_channels - if lstm_layers: - self.lstm = BLSTM(channels, lstm_layers) - else: - self.lstm = None - - if rescale: - rescale_module(self, reference=rescale) - - def valid_length(self, length): - """ - Return the nearest valid length to use with the model so that - there is no time steps left over in a convolution, e.g. for all - layers, size of the input - kernel_size % stride = 0. - - Note that input are automatically padded if necessary to ensure that the output - has the same length as the input. - """ - if self.resample: - length *= 2 - - for _ in range(self.depth): - length = math.ceil((length - self.kernel_size) / self.stride) + 1 - length = max(1, length) - - for idx in range(self.depth): - length = (length - 1) * self.stride + self.kernel_size - - if self.resample: - length = math.ceil(length / 2) - return int(length) - - def forward(self, mix): - x = mix - length = x.shape[-1] - - if self.normalize: - mono = mix.mean(dim=1, keepdim=True) - mean = mono.mean(dim=-1, keepdim=True) - std = mono.std(dim=-1, keepdim=True) - x = (x - mean) / (1e-5 + std) - else: - mean = 0 - std = 1 - - delta = self.valid_length(length) - length - x = F.pad(x, (delta // 2, delta - delta // 2)) - - if self.resample: - x = julius.resample_frac(x, 1, 2) - - saved = [] - for encode in self.encoder: - x = encode(x) - saved.append(x) - - if self.lstm: - x = self.lstm(x) - - for decode in self.decoder: - skip = saved.pop(-1) - skip = center_trim(skip, x) - x = decode(x + skip) - - if self.resample: - x = julius.resample_frac(x, 2, 1) - x = x * std + mean - x = center_trim(x, length) - x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) - return x - - def load_state_dict(self, state, strict=True): - # fix a mismatch with previous generation Demucs models. - for idx in range(self.depth): - for a in ['encoder', 'decoder']: - for b in ['bias', 'weight']: - new = f'{a}.{idx}.3.{b}' - old = f'{a}.{idx}.2.{b}' - if old in state and new not in state: - state[new] = state.pop(old) - super().load_state_dict(state, strict=strict) diff --git a/demucs/distrib.py b/demucs/distrib.py deleted file mode 100644 index b73011a..0000000 --- a/demucs/distrib.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Distributed training utilities. -""" -import logging -import pickle - -import numpy as np -import torch -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import DataLoader, Subset -from torch.nn.parallel.distributed import DistributedDataParallel - -from dora import distrib as dora_distrib - -logger = logging.getLogger(__name__) -rank = 0 -world_size = 1 - - -def init(): - global rank, world_size - if not torch.distributed.is_initialized(): - dora_distrib.init() - rank = dora_distrib.rank() - world_size = dora_distrib.world_size() - - -def average(metrics, count=1.): - if isinstance(metrics, dict): - keys, values = zip(*sorted(metrics.items())) - values = average(values, count) - return dict(zip(keys, values)) - if world_size == 1: - return metrics - tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32) - tensor *= count - torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) - return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() - - -def wrap(model): - if world_size == 1: - return model - else: - return DistributedDataParallel( - model, - # find_unused_parameters=True, - device_ids=[torch.cuda.current_device()], - output_device=torch.cuda.current_device()) - - -def barrier(): - if world_size > 1: - torch.distributed.barrier() - - -def share(obj=None, src=0): - if world_size == 1: - return obj - size = torch.empty(1, device='cuda', dtype=torch.long) - if rank == src: - dump = pickle.dumps(obj) - size[0] = len(dump) - torch.distributed.broadcast(size, src=src) - # size variable is now set to the length of pickled obj in all processes - - if rank == src: - buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda() - else: - buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8) - torch.distributed.broadcast(buffer, src=src) - # buffer variable is now set to pickled obj in all processes - - if rank != src: - obj = pickle.loads(buffer.cpu().numpy().tobytes()) - logger.debug(f"Shared object of size {len(buffer)}") - return obj - - -def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): - """ - Create a dataloader properly in case of distributed training. - If a gradient is going to be computed you must set `shuffle=True`. - """ - if world_size == 1: - return klass(dataset, *args, shuffle=shuffle, **kwargs) - - if shuffle: - # train means we will compute backward, we use DistributedSampler - sampler = DistributedSampler(dataset) - # We ignore shuffle, DistributedSampler already shuffles - return klass(dataset, *args, **kwargs, sampler=sampler) - else: - # We make a manual shard, as DistributedSampler otherwise replicate some examples - dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) - return klass(dataset, *args, shuffle=shuffle, **kwargs) diff --git a/demucs/ema.py b/demucs/ema.py deleted file mode 100644 index 958c595..0000000 --- a/demucs/ema.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Inspired from https://github.com/rwightman/pytorch-image-models -from contextlib import contextmanager - -import torch - -from .states import swap_state - - -class ModelEMA: - """ - Perform EMA on a model. You can switch to the EMA weights temporarily - with the `swap` method. - - ema = ModelEMA(model) - with ema.swap(): - # compute valid metrics with averaged model. - """ - def __init__(self, model, decay=0.9999, unbias=True, device='cpu'): - self.decay = decay - self.model = model - self.state = {} - self.count = 0 - self.device = device - self.unbias = unbias - - self._init() - - def _init(self): - for key, val in self.model.state_dict().items(): - if val.dtype != torch.float32: - continue - device = self.device or val.device - if key not in self.state: - self.state[key] = val.detach().to(device, copy=True) - - def update(self): - if self.unbias: - self.count = self.count * self.decay + 1 - w = 1 / self.count - else: - w = 1 - self.decay - for key, val in self.model.state_dict().items(): - if val.dtype != torch.float32: - continue - device = self.device or val.device - self.state[key].mul_(1 - w) - self.state[key].add_(val.detach().to(device), alpha=w) - - @contextmanager - def swap(self): - with swap_state(self.model, self.state): - yield - - def state_dict(self): - return {'state': self.state, 'count': self.count} - - def load_state_dict(self, state): - self.count = state['count'] - for k, v in state['state'].items(): - self.state[k].copy_(v) diff --git a/demucs/evaluate.py b/demucs/evaluate.py deleted file mode 100644 index badb35e..0000000 --- a/demucs/evaluate.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Test time evaluation, either using the original SDR from [Vincent et al. 2006] -or the newest SDR definition from the MDX 2021 competition (this one will -be reported as `nsdr` for `new sdr`). -""" - -from concurrent import futures -import logging - -from dora.log import LogProgress -import numpy as np -import musdb -import museval -import torch as th - -from .apply import apply_model -from .audio import convert_audio, save_audio -from . import distrib -from .utils import DummyPoolExecutor - - -logger = logging.getLogger(__name__) - - -def new_sdr(references, estimates): - """ - Compute the SDR according to the MDX challenge definition. - Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license) - """ - assert references.dim() == 4 - assert estimates.dim() == 4 - delta = 1e-7 # avoid numerical errors - num = th.sum(th.square(references), dim=(2, 3)) - den = th.sum(th.square(references - estimates), dim=(2, 3)) - num += delta - den += delta - scores = 10 * th.log10(num / den) - return scores - - -def eval_track(references, estimates, win, hop, compute_sdr=True): - references = references.transpose(1, 2).double() - estimates = estimates.transpose(1, 2).double() - - new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0] - - if not compute_sdr: - return None, new_scores - else: - references = references.numpy() - estimates = estimates.numpy() - scores = museval.metrics.bss_eval( - references, estimates, - compute_permutation=False, - window=win, - hop=hop, - framewise_filters=False, - bsseval_sources_version=False)[:-1] - return scores, new_scores - - -def evaluate(solver, compute_sdr=False): - """ - Evaluate model using museval. - `new_only` means using only the MDX definition of the SDR, which is much faster to evaluate. - """ - - args = solver.args - - output_dir = solver.folder / "results" - output_dir.mkdir(exist_ok=True, parents=True) - json_folder = solver.folder / "results/test" - json_folder.mkdir(exist_ok=True, parents=True) - - # we load tracks from the original musdb set - if args.test.nonhq is None: - test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) - else: - test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) - src_rate = args.dset.musdb_samplerate - - eval_device = 'cpu' - - model = solver.model - win = int(1. * model.samplerate) - hop = int(1. * model.samplerate) - - indexes = range(distrib.rank, len(test_set), distrib.world_size) - indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, - name='Eval') - pendings = [] - - pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor - with pool(args.test.workers) as pool: - for index in indexes: - track = test_set.tracks[index] - - mix = th.from_numpy(track.audio).t().float() - if mix.dim() == 1: - mix = mix[None] - mix = mix.to(solver.device) - ref = mix.mean(dim=0) # mono mixture - mix = (mix - ref.mean()) / ref.std() - mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) - estimates = apply_model(model, mix[None], - shifts=args.test.shifts, split=args.test.split, - overlap=args.test.overlap)[0] - estimates = estimates * ref.std() + ref.mean() - estimates = estimates.to(eval_device) - - references = th.stack( - [th.from_numpy(track.targets[name].audio).t() for name in model.sources]) - if references.dim() == 2: - references = references[:, None] - references = references.to(eval_device) - references = convert_audio(references, src_rate, - model.samplerate, model.audio_channels) - if args.test.save: - folder = solver.folder / "wav" / track.name - folder.mkdir(exist_ok=True, parents=True) - for name, estimate in zip(model.sources, estimates): - save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate) - - pendings.append((track.name, pool.submit( - eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr))) - - pendings = LogProgress(logger, pendings, updates=args.misc.num_prints, - name='Eval (BSS)') - tracks = {} - for track_name, pending in pendings: - pending = pending.result() - scores, nsdrs = pending - tracks[track_name] = {} - for idx, target in enumerate(model.sources): - tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]} - if scores is not None: - (sdr, isr, sir, sar) = scores - for idx, target in enumerate(model.sources): - values = { - "SDR": sdr[idx].tolist(), - "SIR": sir[idx].tolist(), - "ISR": isr[idx].tolist(), - "SAR": sar[idx].tolist() - } - tracks[track_name][target].update(values) - - all_tracks = {} - for src in range(distrib.world_size): - all_tracks.update(distrib.share(tracks, src)) - - result = {} - metric_names = next(iter(all_tracks.values()))[model.sources[0]] - for metric_name in metric_names: - avg = 0 - avg_of_medians = 0 - for source in model.sources: - medians = [ - np.nanmedian(all_tracks[track][source][metric_name]) - for track in all_tracks.keys()] - mean = np.mean(medians) - median = np.median(medians) - result[metric_name.lower() + "_" + source] = mean - result[metric_name.lower() + "_med" + "_" + source] = median - avg += mean / len(model.sources) - avg_of_medians += median / len(model.sources) - result[metric_name.lower()] = avg - result[metric_name.lower() + "_med"] = avg_of_medians - return result diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py deleted file mode 100644 index 864fd3f..0000000 --- a/demucs/hdemucs.py +++ /dev/null @@ -1,761 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -This code contains the spectrogram and Hybrid version of Demucs. -""" -from copy import deepcopy -import math - -from openunmix.filtering import wiener -import torch -from torch import nn -from torch.nn import functional as F - -from .demucs import DConv, rescale_module -from .states import capture_init -from .spec import spectro, ispectro - - -class ScaledEmbedding(nn.Module): - """ - Boost learning rate for embeddings (with `scale`). - Also, can make embeddings continuous with `smooth`. - """ - def __init__(self, num_embeddings: int, embedding_dim: int, - scale: float = 10., smooth=False): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim) - if smooth: - weight = torch.cumsum(self.embedding.weight.data, dim=0) - # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. - weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] - self.embedding.weight.data[:] = weight - self.embedding.weight.data /= scale - self.scale = scale - - @property - def weight(self): - return self.embedding.weight * self.scale - - def forward(self, x): - out = self.embedding(x) * self.scale - return out - - -class HEncLayer(nn.Module): - def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, - freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, - rewrite=True): - """Encoder layer. This used both by the time and the frequency branch. - - Args: - chin: number of input channels. - chout: number of output channels. - norm_groups: number of groups for group norm. - empty: used to make a layer with just the first conv. this is used - before merging the time and freq. branches. - freq: this is acting on frequencies. - dconv: insert DConv residual branches. - norm: use GroupNorm. - context: context size for the 1x1 conv. - dconv_kw: list of kwargs for the DConv class. - pad: pad the input. Padding is done so that the output size is - always the input size / stride. - rewrite: add 1x1 conv at the end of the layer. - """ - super().__init__() - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - if pad: - pad = kernel_size // 4 - else: - pad = 0 - klass = nn.Conv1d - self.freq = freq - self.kernel_size = kernel_size - self.stride = stride - self.empty = empty - self.norm = norm - self.pad = pad - if freq: - kernel_size = [kernel_size, 1] - stride = [stride, 1] - pad = [pad, 0] - klass = nn.Conv2d - self.conv = klass(chin, chout, kernel_size, stride, pad) - if self.empty: - return - self.norm1 = norm_fn(chout) - self.rewrite = None - if rewrite: - self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) - self.norm2 = norm_fn(2 * chout) - - self.dconv = None - if dconv: - self.dconv = DConv(chout, **dconv_kw) - - def forward(self, x, inject=None): - """ - `inject` is used to inject the result from the time branch into the frequency branch, - when both have the same stride. - """ - if not self.freq and x.dim() == 4: - B, C, Fr, T = x.shape - x = x.view(B, -1, T) - - if not self.freq: - le = x.shape[-1] - if not le % self.stride == 0: - x = F.pad(x, (0, self.stride - (le % self.stride))) - y = self.conv(x) - if self.empty: - return y - if inject is not None: - assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) - if inject.dim() == 3 and y.dim() == 4: - inject = inject[:, :, None] - y = y + inject - y = F.gelu(self.norm1(y)) - if self.dconv: - if self.freq: - B, C, Fr, T = y.shape - y = y.permute(0, 2, 1, 3).reshape(-1, C, T) - y = self.dconv(y) - if self.freq: - y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) - if self.rewrite: - z = self.norm2(self.rewrite(y)) - z = F.glu(z, dim=1) - else: - z = y - return z - - -class MultiWrap(nn.Module): - """ - Takes one layer and replicate it N times. each replica will act - on a frequency band. All is done so that if the N replica have the same weights, - then this is exactly equivalent to applying the original module on all frequencies. - - This is a bit over-engineered to avoid edge artifacts when splitting - the frequency bands, but it is possible the naive implementation would work as well... - """ - def __init__(self, layer, split_ratios): - """ - Args: - layer: module to clone, must be either HEncLayer or HDecLayer. - split_ratios: list of float indicating which ratio to keep for each band. - """ - super().__init__() - self.split_ratios = split_ratios - self.layers = nn.ModuleList() - self.conv = isinstance(layer, HEncLayer) - assert not layer.norm - assert layer.freq - assert layer.pad - if not self.conv: - assert not layer.context_freq - for k in range(len(split_ratios) + 1): - lay = deepcopy(layer) - if self.conv: - lay.conv.padding = (0, 0) - else: - lay.pad = False - for m in lay.modules(): - if hasattr(m, 'reset_parameters'): - m.reset_parameters() - self.layers.append(lay) - - def forward(self, x, skip=None, length=None): - B, C, Fr, T = x.shape - - ratios = list(self.split_ratios) + [1] - start = 0 - outs = [] - for ratio, layer in zip(ratios, self.layers): - if self.conv: - pad = layer.kernel_size // 4 - if ratio == 1: - limit = Fr - frames = -1 - else: - limit = int(round(Fr * ratio)) - le = limit - start - if start == 0: - le += pad - frames = round((le - layer.kernel_size) / layer.stride + 1) - limit = start + (frames - 1) * layer.stride + layer.kernel_size - if start == 0: - limit -= pad - assert limit - start > 0, (limit, start) - assert limit <= Fr, (limit, Fr) - y = x[:, :, start:limit, :] - if start == 0: - y = F.pad(y, (0, 0, pad, 0)) - if ratio == 1: - y = F.pad(y, (0, 0, 0, pad)) - outs.append(layer(y)) - start = limit - layer.kernel_size + layer.stride - else: - if ratio == 1: - limit = Fr - else: - limit = int(round(Fr * ratio)) - last = layer.last - layer.last = True - - y = x[:, :, start:limit] - s = skip[:, :, start:limit] - out, _ = layer(y, s, None) - if outs: - outs[-1][:, :, -layer.stride:] += ( - out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) - out = out[:, :, layer.stride:] - if ratio == 1: - out = out[:, :, :-layer.stride // 2, :] - if start == 0: - out = out[:, :, layer.stride // 2:, :] - outs.append(out) - layer.last = last - start = limit - out = torch.cat(outs, dim=2) - if not self.conv and not last: - out = F.gelu(out) - if self.conv: - return out - else: - return out, None - - -class HDecLayer(nn.Module): - def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, - freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, - context_freq=True, rewrite=True): - """ - Same as HEncLayer but for decoder. See `HEncLayer` for documentation. - """ - super().__init__() - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - if pad: - pad = kernel_size // 4 - else: - pad = 0 - self.pad = pad - self.last = last - self.freq = freq - self.chin = chin - self.empty = empty - self.stride = stride - self.kernel_size = kernel_size - self.norm = norm - self.context_freq = context_freq - klass = nn.Conv1d - klass_tr = nn.ConvTranspose1d - if freq: - kernel_size = [kernel_size, 1] - stride = [stride, 1] - klass = nn.Conv2d - klass_tr = nn.ConvTranspose2d - self.conv_tr = klass_tr(chin, chout, kernel_size, stride) - self.norm2 = norm_fn(chout) - if self.empty: - return - self.rewrite = None - if rewrite: - if context_freq: - self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) - else: - self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, - [0, context]) - self.norm1 = norm_fn(2 * chin) - - self.dconv = None - if dconv: - self.dconv = DConv(chin, **dconv_kw) - - def forward(self, x, skip, length): - if self.freq and x.dim() == 3: - B, C, T = x.shape - x = x.view(B, self.chin, -1, T) - - if not self.empty: - x = x + skip - - if self.rewrite: - y = F.glu(self.norm1(self.rewrite(x)), dim=1) - else: - y = x - if self.dconv: - if self.freq: - B, C, Fr, T = y.shape - y = y.permute(0, 2, 1, 3).reshape(-1, C, T) - y = self.dconv(y) - if self.freq: - y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) - else: - y = x - assert skip is None - z = self.norm2(self.conv_tr(y)) - if self.freq: - if self.pad: - z = z[..., self.pad:-self.pad, :] - else: - z = z[..., self.pad:self.pad + length] - assert z.shape[-1] == length, (z.shape[-1], length) - if not self.last: - z = F.gelu(z) - return z, y - - -class HDemucs(nn.Module): - """ - Spectrogram and hybrid Demucs model. - The spectrogram model has the same structure as Demucs, except the first few layers are over the - frequency axis, until there is only 1 frequency, and then it moves to time convolutions. - Frequency layers can still access information across time steps thanks to the DConv residual. - - Hybrid model have a parallel time branch. At some layer, the time branch has the same stride - as the frequency branch and then the two are combined. The opposite happens in the decoder. - - Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), - or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on - Open Unmix implementation [Stoter et al. 2019]. - - The loss is always on the temporal domain, by backpropagating through the above - output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks - a bit Wiener filtering, as doing more iteration at test time will change the spectrogram - contribution, without changing the one from the waveform, which will lead to worse performance. - I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. - CaC on the other hand provides similar performance for hybrid, and works naturally with - hybrid models. - - This model also uses frequency embeddings are used to improve efficiency on convolutions - over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). - - Unlike classic Demucs, there is no resampling here, and normalization is always applied. - """ - @capture_init - def __init__(self, - sources, - # Channels - audio_channels=2, - channels=48, - channels_time=None, - growth=2, - # STFT - nfft=4096, - wiener_iters=0, - end_iters=0, - wiener_residual=False, - cac=True, - # Main structure - depth=6, - rewrite=True, - hybrid=True, - hybrid_old=False, - # Frequency branch - multi_freqs=None, - multi_freqs_depth=2, - freq_emb=0.2, - emb_scale=10, - emb_smooth=True, - # Convolutions - kernel_size=8, - time_stride=2, - stride=4, - context=1, - context_enc=0, - # Normalization - norm_starts=4, - norm_groups=4, - # DConv residual branch - dconv_mode=1, - dconv_depth=2, - dconv_comp=4, - dconv_attn=4, - dconv_lstm=4, - dconv_init=1e-4, - # Weight init - rescale=0.1, - # Metadata - samplerate=44100, - segment=4 * 10): - """ - Args: - sources (list[str]): list of source names. - audio_channels (int): input/output audio channels. - channels (int): initial number of hidden channels. - channels_time: if not None, use a different `channels` value for the time branch. - growth: increase the number of hidden channels by this factor at each layer. - nfft: number of fft bins. Note that changing this require careful computation of - various shape parameters and will not work out of the box for hybrid models. - wiener_iters: when using Wiener filtering, number of iterations at test time. - end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. - wiener_residual: add residual source before wiener filtering. - cac: uses complex as channels, i.e. complex numbers are 2 channels each - in input and output. no further processing is done before ISTFT. - depth (int): number of layers in the encoder and in the decoder. - rewrite (bool): add 1x1 convolution to each layer. - hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. - hybrid_old: some models trained for MDX had a padding bug. This replicates - this bug to avoid retraining them. - multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. - multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost - layers will be wrapped. - freq_emb: add frequency embedding after the first frequency layer if > 0, - the actual value controls the weight of the embedding. - emb_scale: equivalent to scaling the embedding learning rate - emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). - kernel_size: kernel_size for encoder and decoder layers. - stride: stride for encoder and decoder layers. - time_stride: stride for the final time layer, after the merge. - context: context for 1x1 conv in the decoder. - context_enc: context for 1x1 conv in the encoder. - norm_starts: layer at which group norm starts being used. - decoder layers are numbered in reverse order. - norm_groups: number of groups for group norm. - dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. - dconv_depth: depth of residual DConv branch. - dconv_comp: compression of DConv branch. - dconv_attn: adds attention layers in DConv branch starting at this layer. - dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. - dconv_init: initial scale for the DConv branch LayerScale. - rescale: weight recaling trick - - """ - super().__init__() - self.cac = cac - self.wiener_residual = wiener_residual - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.channels = channels - self.samplerate = samplerate - self.segment = segment - - self.nfft = nfft - self.hop_length = nfft // 4 - self.wiener_iters = wiener_iters - self.end_iters = end_iters - self.freq_emb = None - self.hybrid = hybrid - self.hybrid_old = hybrid_old - if hybrid_old: - assert hybrid, "hybrid_old must come with hybrid=True" - if hybrid: - assert wiener_iters == end_iters - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - if hybrid: - self.tencoder = nn.ModuleList() - self.tdecoder = nn.ModuleList() - - chin = audio_channels - chin_z = chin # number of channels for the freq branch - if self.cac: - chin_z *= 2 - chout = channels_time or channels - chout_z = channels - freqs = nfft // 2 - - for index in range(depth): - lstm = index >= dconv_lstm - attn = index >= dconv_attn - norm = index >= norm_starts - freq = freqs > 1 - stri = stride - ker = kernel_size - if not freq: - assert freqs == 1 - ker = time_stride * 2 - stri = time_stride - - pad = True - last_freq = False - if freq and freqs <= kernel_size: - ker = freqs - pad = False - last_freq = True - - kw = { - 'kernel_size': ker, - 'stride': stri, - 'freq': freq, - 'pad': pad, - 'norm': norm, - 'rewrite': rewrite, - 'norm_groups': norm_groups, - 'dconv_kw': { - 'lstm': lstm, - 'attn': attn, - 'depth': dconv_depth, - 'compress': dconv_comp, - 'init': dconv_init, - 'gelu': True, - } - } - kwt = dict(kw) - kwt['freq'] = 0 - kwt['kernel_size'] = kernel_size - kwt['stride'] = stride - kwt['pad'] = True - kw_dec = dict(kw) - multi = False - if multi_freqs and index < multi_freqs_depth: - multi = True - kw_dec['context_freq'] = False - - if last_freq: - chout_z = max(chout, chout_z) - chout = chout_z - - enc = HEncLayer(chin_z, chout_z, - dconv=dconv_mode & 1, context=context_enc, **kw) - if hybrid and freq: - tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, - empty=last_freq, **kwt) - self.tencoder.append(tenc) - - if multi: - enc = MultiWrap(enc, multi_freqs) - self.encoder.append(enc) - if index == 0: - chin = self.audio_channels * len(self.sources) - chin_z = chin - if self.cac: - chin_z *= 2 - dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, - last=index == 0, context=context, **kw_dec) - if multi: - dec = MultiWrap(dec, multi_freqs) - if hybrid and freq: - tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, - last=index == 0, context=context, **kwt) - self.tdecoder.insert(0, tdec) - self.decoder.insert(0, dec) - - chin = chout - chin_z = chout_z - chout = int(growth * chout) - chout_z = int(growth * chout_z) - if freq: - if freqs <= kernel_size: - freqs = 1 - else: - freqs //= stride - if index == 0 and freq_emb: - self.freq_emb = ScaledEmbedding( - freqs, chin_z, smooth=emb_smooth, scale=emb_scale) - self.freq_emb_scale = freq_emb - - if rescale: - rescale_module(self, reference=rescale) - - def _spec(self, x): - hl = self.hop_length - nfft = self.nfft - x0 = x # noqa - - if self.hybrid: - # We re-pad the signal in order to keep the property - # that the size of the output is exactly the size of the input - # divided by the stride (here hop_length), when divisible. - # This is achieved by padding by 1/4th of the kernel size (here nfft). - # which is not supported by torch.stft. - # Having all convolution operations follow this convention allow to easily - # align the time and frequency branches later on. - assert hl == nfft // 4 - le = int(math.ceil(x.shape[-1] / hl)) - pad = hl // 2 * 3 - if not self.hybrid_old: - x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') - else: - x = F.pad(x, (pad, pad + le * hl - x.shape[-1])) - - z = spectro(x, nfft, hl)[..., :-1, :] - if self.hybrid: - assert z.shape[-1] == le + 4, (z.shape, x.shape, le) - z = z[..., 2:2+le] - return z - - def _ispec(self, z, length=None, scale=0): - hl = self.hop_length // (4 ** scale) - z = F.pad(z, (0, 0, 0, 1)) - if self.hybrid: - z = F.pad(z, (2, 2)) - pad = hl // 2 * 3 - if not self.hybrid_old: - le = hl * int(math.ceil(length / hl)) + 2 * pad - else: - le = hl * int(math.ceil(length / hl)) - x = ispectro(z, hl, length=le) - if not self.hybrid_old: - x = x[..., pad:pad + length] - else: - x = x[..., :length] - else: - x = ispectro(z, hl, length) - return x - - def _magnitude(self, z): - # return the magnitude of the spectrogram, except when cac is True, - # in which case we just move the complex dimension to the channel one. - if self.cac: - B, C, Fr, T = z.shape - m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) - m = m.reshape(B, C * 2, Fr, T) - else: - m = z.abs() - return m - - def _mask(self, z, m): - # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. - # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. - niters = self.wiener_iters - if self.cac: - B, S, C, Fr, T = m.shape - out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) - out = torch.view_as_complex(out.contiguous()) - return out - if self.training: - niters = self.end_iters - if niters < 0: - z = z[:, None] - return z / (1e-8 + z.abs()) * m - else: - return self._wiener(m, z, niters) - - def _wiener(self, mag_out, mix_stft, niters): - # apply wiener filtering from OpenUnmix. - init = mix_stft.dtype - wiener_win_len = 300 - residual = self.wiener_residual - - B, S, C, Fq, T = mag_out.shape - mag_out = mag_out.permute(0, 4, 3, 2, 1) - mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) - - outs = [] - for sample in range(B): - pos = 0 - out = [] - for pos in range(0, T, wiener_win_len): - frame = slice(pos, pos + wiener_win_len) - z_out = wiener( - mag_out[sample, frame], mix_stft[sample, frame], niters, - residual=residual) - out.append(z_out.transpose(-1, -2)) - outs.append(torch.cat(out, dim=0)) - out = torch.view_as_complex(torch.stack(outs, 0)) - out = out.permute(0, 4, 3, 2, 1).contiguous() - if residual: - out = out[:, :-1] - assert list(out.shape) == [B, S, C, Fq, T] - return out.to(init) - - def forward(self, mix): - x = mix - length = x.shape[-1] - - z = self._spec(mix) - mag = self._magnitude(z) - x = mag - - B, C, Fq, T = x.shape - - # unlike previous Demucs, we always normalize because it is easier. - mean = x.mean(dim=(1, 2, 3), keepdim=True) - std = x.std(dim=(1, 2, 3), keepdim=True) - x = (x - mean) / (1e-5 + std) - # x will be the freq. branch input. - - if self.hybrid: - # Prepare the time branch input. - xt = mix - meant = xt.mean(dim=(1, 2), keepdim=True) - stdt = xt.std(dim=(1, 2), keepdim=True) - xt = (xt - meant) / (1e-5 + stdt) - - # okay, this is a giant mess I know... - saved = [] # skip connections, freq. - saved_t = [] # skip connections, time. - lengths = [] # saved lengths to properly remove padding, freq branch. - lengths_t = [] # saved lengths for time branch. - for idx, encode in enumerate(self.encoder): - lengths.append(x.shape[-1]) - inject = None - if self.hybrid and idx < len(self.tencoder): - # we have not yet merged branches. - lengths_t.append(xt.shape[-1]) - tenc = self.tencoder[idx] - xt = tenc(xt) - if not tenc.empty: - # save for skip connection - saved_t.append(xt) - else: - # tenc contains just the first conv., so that now time and freq. - # branches have the same shape and can be merged. - inject = xt - x = encode(x, inject) - if idx == 0 and self.freq_emb is not None: - # add frequency embedding to allow for non equivariant convolutions - # over the frequency axis. - frs = torch.arange(x.shape[-2], device=x.device) - emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) - x = x + self.freq_emb_scale * emb - - saved.append(x) - - x = torch.zeros_like(x) - if self.hybrid: - xt = torch.zeros_like(x) - # initialize everything to zero (signal will go through u-net skips). - - for idx, decode in enumerate(self.decoder): - skip = saved.pop(-1) - x, pre = decode(x, skip, lengths.pop(-1)) - # `pre` contains the output just before final transposed convolution, - # which is used when the freq. and time branch separate. - - if self.hybrid: - offset = self.depth - len(self.tdecoder) - if self.hybrid and idx >= offset: - tdec = self.tdecoder[idx - offset] - length_t = lengths_t.pop(-1) - if tdec.empty: - assert pre.shape[2] == 1, pre.shape - pre = pre[:, :, 0] - xt, _ = tdec(pre, None, length_t) - else: - skip = saved_t.pop(-1) - xt, _ = tdec(xt, skip, length_t) - - # Let's make sure we used all stored skip connections. - assert len(saved) == 0 - assert len(lengths_t) == 0 - assert len(saved_t) == 0 - - S = len(self.sources) - x = x.view(B, S, -1, Fq, T) - x = x * std[:, None] + mean[:, None] - - zout = self._mask(z, x) - x = self._ispec(zout, length) - - if self.hybrid: - xt = xt.view(B, S, -1, length) - xt = xt * stdt[:, None] + meant[:, None] - x = xt + x - return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py deleted file mode 100644 index 1c976c6..0000000 --- a/demucs/pretrained.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Loading pretrained models. -""" - -import logging -from pathlib import Path -import typing as tp - -from dora.log import fatal - -from .hdemucs import HDemucs -from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa - -logger = logging.getLogger(__name__) -ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" -REMOTE_ROOT = Path(__file__).parent / 'remote' - -SOURCES = ["drums", "bass", "other", "vocals"] - - -def demucs_unittest(): - model = HDemucs(channels=4, sources=SOURCES) - return model - - -def add_model_flags(parser): - group = parser.add_mutually_exclusive_group(required=False) - group.add_argument("-s", "--sig", help="Locally trained XP signature.") - group.add_argument("-n", "--name", default="mdx_extra_q", - help="Pretrained model name or signature. Default is mdx_extra_q.") - parser.add_argument("--repo", type=Path, - help="Folder containing all pre-trained models for use with -n.") - - -def get_model(name: str, - repo: tp.Optional[Path] = None): - """`name` must be a bag of models name or a pretrained signature - from the remote AWS model repo or the specified local repo if `repo` is not None. - """ - if name == 'demucs_unittest': - return demucs_unittest() - model_repo: ModelOnlyRepo - if repo is None: - remote_files = [line.strip() - for line in (REMOTE_ROOT / 'files.txt').read_text().split('\n') - if line.strip()] - model_repo = RemoteRepo(ROOT_URL, remote_files) - bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) - else: - if not repo.is_dir(): - fatal(f"{repo} must exist and be a directory.") - model_repo = LocalRepo(repo) - bag_repo = BagOnlyRepo(repo, model_repo) - any_repo = AnyModelRepo(model_repo, bag_repo) - return any_repo.get_model(name) - - -def get_model_from_args(args): - """ - Load local model package or pre-trained model. - """ - return get_model(name=args.name, repo=args.repo) diff --git a/demucs/repo.py b/demucs/repo.py deleted file mode 100644 index f79c532..0000000 --- a/demucs/repo.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Represents a model repository, including pre-trained models and bags of models. -A repo can either be the main remote repository stored in AWS, or a local repository -with your own models. -""" - -from hashlib import sha256 -from pathlib import Path -import typing as tp - -import torch -import yaml - -from .apply import BagOfModels, Model -from .states import load_model - - -AnyModel = tp.Union[Model, BagOfModels] - - -class ModelLoadingError(RuntimeError): - pass - - -def check_checksum(path: Path, checksum: str): - sha = sha256() - with open(path, 'rb') as file: - while True: - buf = file.read(2**20) - if not buf: - break - sha.update(buf) - actual_checksum = sha.hexdigest()[:len(checksum)] - if actual_checksum != checksum: - raise ModelLoadingError(f'Invalid checksum for file {path}, ' - f'expected {checksum} but got {actual_checksum}') - - -class ModelOnlyRepo: - """Base class for all model only repos. - """ - def has_model(self, sig: str) -> bool: - raise NotImplementedError() - - def get_model(self, sig: str) -> Model: - raise NotImplementedError() - - -class RemoteRepo(ModelOnlyRepo): - def __init__(self, root_url: str, remote_files: tp.List[str]): - if not root_url.endswith('/'): - root_url += '/' - self._models: tp.Dict[str, str] = {} - for file in remote_files: - sig, checksum = file.split('.')[0].split('-') - assert sig not in self._models - self._models[sig] = root_url + file - - def has_model(self, sig: str) -> bool: - return sig in self._models - - def get_model(self, sig: str) -> Model: - try: - url = self._models[sig] - except KeyError: - raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') - pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) - return load_model(pkg) - - -class LocalRepo(ModelOnlyRepo): - def __init__(self, root: Path): - self.root = root - self.scan() - - def scan(self): - self._models = {} - self._checksums = {} - for file in self.root.iterdir(): - if file.suffix == '.th': - if '-' in file.stem: - xp_sig, checksum = file.stem.split('-') - self._checksums[xp_sig] = checksum - else: - xp_sig = file.stem - if xp_sig in self._models: - raise ModelLoadingError( - f'Duplicate pre-trained model exist for signature {xp_sig}. ' - 'Please delete all but one.') - self._models[xp_sig] = file - - def has_model(self, sig: str) -> bool: - return sig in self._models - - def get_model(self, sig: str) -> Model: - try: - file = self._models[sig] - except KeyError: - raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') - if sig in self._checksums: - check_checksum(file, self._checksums[sig]) - return load_model(file) - - -class BagOnlyRepo: - """Handles only YAML files containing bag of models, leaving the actual - model loading to some Repo. - """ - def __init__(self, root: Path, model_repo: ModelOnlyRepo): - self.root = root - self.model_repo = model_repo - self.scan() - - def scan(self): - self._bags = {} - for file in self.root.iterdir(): - if file.suffix == '.yaml': - self._bags[file.stem] = file - - def has_model(self, name: str) -> bool: - return name in self._bags - - def get_model(self, name: str) -> BagOfModels: - try: - yaml_file = self._bags[name] - except KeyError: - raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' - 'a bag of models.') - bag = yaml.safe_load(open(yaml_file)) - signatures = bag['models'] - models = [self.model_repo.get_model(sig) for sig in signatures] - weights = bag.get('weights') - segment = bag.get('segment') - return BagOfModels(models, weights, segment) - - -class AnyModelRepo: - def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): - self.model_repo = model_repo - self.bag_repo = bag_repo - - def has_model(self, name_or_sig: str) -> bool: - return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) - - def get_model(self, name_or_sig: str) -> AnyModel: - if self.model_repo.has_model(name_or_sig): - return self.model_repo.get_model(name_or_sig) - else: - return self.bag_repo.get_model(name_or_sig) diff --git a/demucs/separate.py b/demucs/separate.py deleted file mode 100644 index 1554ce3..0000000 --- a/demucs/separate.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import sys -from pathlib import Path -import subprocess - -from dora.log import fatal -import torch as th -import torchaudio as ta - -from .apply import apply_model, BagOfModels -from .audio import AudioFile, convert_audio, save_audio -from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError - - -def load_track(track, audio_channels, samplerate): - errors = {} - wav = None - - try: - wav = AudioFile(track).read( - streams=0, - samplerate=samplerate, - channels=audio_channels) - except FileNotFoundError: - errors['ffmpeg'] = 'Ffmpeg is not installed.' - except subprocess.CalledProcessError: - errors['ffmpeg'] = 'FFmpeg could not read the file.' - - if wav is None: - try: - wav, sr = ta.load(str(track)) - except RuntimeError as err: - errors['torchaudio'] = err.args[0] - else: - wav = convert_audio(wav, sr, samplerate, audio_channels) - - if wav is None: - print(f"Could not load file {track}. " - "Maybe it is not a supported file format? ") - for backend, error in errors.items(): - print(f"When trying to load using {backend}, got the following error: {error}") - sys.exit(1) - return wav - - -def main(): - parser = argparse.ArgumentParser("demucs.separate", - description="Separate the sources for the given tracks") - parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks') - add_model_flags(parser) - parser.add_argument("-v", "--verbose", action="store_true") - parser.add_argument("-o", - "--out", - type=Path, - default=Path("separated"), - help="Folder where to put extracted tracks. A subfolder " - "with the model name will be created.") - parser.add_argument("-d", - "--device", - default="cuda" if th.cuda.is_available() else "cpu", - help="Device to use, default is cuda if available else cpu") - parser.add_argument("--shifts", - default=1, - type=int, - help="Number of random shifts for equivariant stabilization." - "Increase separation time but improves quality for Demucs. 10 was used " - "in the original paper.") - parser.add_argument("--overlap", - default=0.25, - type=float, - help="Overlap between the splits.") - split_group = parser.add_mutually_exclusive_group() - split_group.add_argument("--no-split", - action="store_false", - dest="split", - default=True, - help="Doesn't split audio in chunks. " - "This can use large amounts of memory.") - split_group.add_argument("--segment", type=int, - help="Set split size of each chunk. " - "This can help save memory of graphic card. ") - parser.add_argument("--two-stems", - dest="stem", metavar="STEM", - help="Only separate audio into {STEM} and no_{STEM}. ") - group = parser.add_mutually_exclusive_group() - group.add_argument("--int24", action="store_true", - help="Save wav output as 24 bits wav.") - group.add_argument("--float32", action="store_true", - help="Save wav output as float32 (2x bigger).") - parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp"], - help="Strategy for avoiding clipping: rescaling entire signal " - "if necessary (rescale) or hard clipping (clamp).") - parser.add_argument("--mp3", action="store_true", - help="Convert the output wavs to mp3.") - parser.add_argument("--mp3-bitrate", - default=320, - type=int, - help="Bitrate of converted mp3.") - parser.add_argument("-j", "--jobs", - default=0, - type=int, - help="Number of jobs. This can increase memory usage but will " - "be much faster when multiple cores are available.") - - args = parser.parse_args() - - try: - model = get_model_from_args(args) - except ModelLoadingError as error: - fatal(error.args[0]) - - if args.segment is not None and args.segment < 8: - fatal('Segment must greater than 8. ') - - if isinstance(model, BagOfModels): - if args.segment is not None: - for sub in model.models: - sub.segment = args.segment - else: - if args.segment is not None: - sub.segment = args.segment - - model.cpu() - model.eval() - - if args.stem is not None and args.stem not in model.sources: - fatal( - 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format( - stem=args.stem, sources=', '.join(model.sources))) - out = args.out / args.name - out.mkdir(parents=True, exist_ok=True) - print(f"Separated tracks will be stored in {out.resolve()}") - for track in args.tracks: - if not track.exists(): - print( - f"File {track} does not exist. If the path contains spaces, " - "please try again after surrounding the entire path with quotes \"\".", - file=sys.stderr) - continue - print(f"Separating track {track}") - wav = load_track(track, model.audio_channels, model.samplerate) - - ref = wav.mean(0) - wav = (wav - ref.mean()) / ref.std() - sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts, - split=args.split, overlap=args.overlap, progress=True, - num_workers=args.jobs)[0] - sources = sources * ref.std() + ref.mean() - - track_folder = out / track.name.rsplit(".", 1)[0] - track_folder.mkdir(exist_ok=True) - if args.mp3: - ext = ".mp3" - else: - ext = ".wav" - kwargs = { - 'samplerate': model.samplerate, - 'bitrate': args.mp3_bitrate, - 'clip': args.clip_mode, - 'as_float': args.float32, - 'bits_per_sample': 24 if args.int24 else 16, - } - if args.stem is None: - for source, name in zip(sources, model.sources): - stem = str(track_folder / (name + ext)) - save_audio(source, stem, **kwargs) - else: - sources = list(sources) - stem = str(track_folder / (args.stem + ext)) - save_audio(sources.pop(model.sources.index(args.stem)), stem, **kwargs) - # Warning : after poping the stem, selected stem is no longer in the list 'sources' - other_stem = th.zeros_like(sources[0]) - for i in sources: - other_stem += i - stem = str(track_folder / ("no_" + args.stem + ext)) - save_audio(other_stem, stem, **kwargs) - - -if __name__ == "__main__": - main() diff --git a/demucs/solver.py b/demucs/solver.py deleted file mode 100644 index 9970615..0000000 --- a/demucs/solver.py +++ /dev/null @@ -1,404 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Main training loop.""" - -import logging - -from dora import get_xp -from dora.utils import write_and_rename -from dora.log import LogProgress, bold -import torch -import torch.nn.functional as F - -from . import augment, distrib, states, pretrained -from .apply import apply_model -from .ema import ModelEMA -from .evaluate import evaluate, new_sdr -from .svd import svd_penalty -from .utils import pull_metric, EMA - -logger = logging.getLogger(__name__) - - -def _summary(metrics): - return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items()) - - -class Solver(object): - def __init__(self, loaders, model, optimizer, args): - self.args = args - self.loaders = loaders - - self.model = model - self.optimizer = optimizer - self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer) - self.dmodel = distrib.wrap(model) - self.device = next(iter(self.model.parameters())).device - - # Exponential moving average of the model, either updated every batch or epoch. - # The best model from all the EMAs and the original one is kept based on the valid - # loss for the final best model. - self.emas = {'batch': [], 'epoch': []} - for kind in self.emas.keys(): - decays = getattr(args.ema, kind) - device = self.device if kind == 'batch' else 'cpu' - if decays: - for decay in decays: - self.emas[kind].append(ModelEMA(self.model, decay, device=device)) - - # data augment - augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift), - same=args.augment.shift_same)] - if args.augment.flip: - augments += [augment.FlipChannels(), augment.FlipSign()] - for aug in ['scale', 'remix']: - kw = getattr(args.augment, aug) - if kw.proba: - augments.append(getattr(augment, aug.capitalize())(**kw)) - self.augment = torch.nn.Sequential(*augments) - - xp = get_xp() - self.folder = xp.folder - # Checkpoints - self.checkpoint_file = xp.folder / 'checkpoint.th' - self.best_file = xp.folder / 'best.th' - logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve()) - self.best_state = None - self.best_changed = False - - self.link = xp.link - self.history = self.link.history - - self._reset() - - def _serialize(self, epoch): - package = {} - package['state'] = self.model.state_dict() - package['optimizer'] = self.optimizer.state_dict() - package['history'] = self.history - package['best_state'] = self.best_state - package['args'] = self.args - for kind, emas in self.emas.items(): - for k, ema in enumerate(emas): - package[f'ema_{kind}_{k}'] = ema.state_dict() - with write_and_rename(self.checkpoint_file) as tmp: - torch.save(package, tmp) - - save_every = self.args.save_every - if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs: - with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp: - torch.save(package, tmp) - - if self.best_changed: - # Saving only the latest best model. - with write_and_rename(self.best_file) as tmp: - package = states.serialize_model(self.model, self.args) - package['state'] = self.best_state - torch.save(package, tmp) - self.best_changed = False - - def _reset(self): - """Reset state of the solver, potentially using checkpoint.""" - if self.checkpoint_file.exists(): - logger.info(f'Loading checkpoint model: {self.checkpoint_file}') - package = torch.load(self.checkpoint_file, 'cpu') - self.model.load_state_dict(package['state']) - self.optimizer.load_state_dict(package['optimizer']) - self.history[:] = package['history'] - self.best_state = package['best_state'] - for kind, emas in self.emas.items(): - for k, ema in enumerate(emas): - ema.load_state_dict(package[f'ema_{kind}_{k}']) - elif self.args.continue_pretrained: - model = pretrained.get_model( - name=self.args.continue_pretrained, - repo=self.args.pretrained_repo) - self.model.load_state_dict(model.state_dict()) - elif self.args.continue_from: - name = 'checkpoint.th' - root = self.folder.parent - cf = root / str(self.args.continue_from) / name - logger.info("Loading from %s", cf) - package = torch.load(cf, 'cpu') - self.best_state = package['best_state'] - if self.args.continue_best: - self.model.load_state_dict(package['best_state'], strict=False) - else: - self.model.load_state_dict(package['state'], strict=False) - if self.args.continue_opt: - self.optimizer.load_state_dict(package['optimizer']) - - def _format_train(self, metrics: dict) -> dict: - """Formatting for train/valid metrics.""" - losses = { - 'loss': format(metrics['loss'], ".4f"), - 'reco': format(metrics['reco'], ".4f"), - } - if 'nsdr' in metrics: - losses['nsdr'] = format(metrics['nsdr'], ".3f") - if self.quantizer is not None: - losses['ms'] = format(metrics['ms'], ".2f") - if 'grad' in metrics: - losses['grad'] = format(metrics['grad'], ".4f") - if 'best' in metrics: - losses['best'] = format(metrics['best'], '.4f') - if 'bname' in metrics: - losses['bname'] = metrics['bname'] - if 'penalty' in metrics: - losses['penalty'] = format(metrics['penalty'], ".4f") - if 'hloss' in metrics: - losses['hloss'] = format(metrics['hloss'], ".4f") - return losses - - def _format_test(self, metrics: dict) -> dict: - """Formatting for test metrics.""" - losses = {} - if 'sdr' in metrics: - losses['sdr'] = format(metrics['sdr'], '.3f') - if 'nsdr' in metrics: - losses['nsdr'] = format(metrics['nsdr'], '.3f') - for source in self.model.sources: - key = f'sdr_{source}' - if key in metrics: - losses[key] = format(metrics[key], '.3f') - key = f'nsdr_{source}' - if key in metrics: - losses[key] = format(metrics[key], '.3f') - return losses - - def train(self): - # Optimizing the model - if self.history: - logger.info("Replaying metrics from previous run") - for epoch, metrics in enumerate(self.history): - formatted = self._format_train(metrics['train']) - logger.info( - bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) - formatted = self._format_train(metrics['valid']) - logger.info( - bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) - if 'test' in metrics: - formatted = self._format_test(metrics['test']) - if formatted: - logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) - - epoch = 0 - for epoch in range(len(self.history), self.args.epochs): - # Train one epoch - self.model.train() # Turn on BatchNorm & Dropout - metrics = {} - logger.info('-' * 70) - logger.info("Training...") - metrics['train'] = self._run_one_epoch(epoch) - formatted = self._format_train(metrics['train']) - logger.info( - bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) - - # Cross validation - logger.info('-' * 70) - logger.info('Cross validation...') - self.model.eval() # Turn off Batchnorm & Dropout - with torch.no_grad(): - valid = self._run_one_epoch(epoch, train=False) - bvalid = valid - bname = 'main' - state = states.copy_state(self.model.state_dict()) - metrics['valid'] = {} - metrics['valid']['main'] = valid - key = self.args.test.metric - for kind, emas in self.emas.items(): - for k, ema in enumerate(emas): - with ema.swap(): - valid = self._run_one_epoch(epoch, train=False) - name = f'ema_{kind}_{k}' - metrics['valid'][name] = valid - a = valid[key] - b = bvalid[key] - if key.startswith('nsdr'): - a = -a - b = -b - if a < b: - bvalid = valid - state = ema.state - bname = name - metrics['valid'].update(bvalid) - metrics['valid']['bname'] = bname - - valid_loss = metrics['valid'][key] - mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss] - if key.startswith('nsdr'): - best_loss = max(mets) - else: - best_loss = min(mets) - metrics['valid']['best'] = best_loss - if self.args.svd.penalty > 0: - kw = dict(self.args.svd) - kw.pop('penalty') - with torch.no_grad(): - penalty = svd_penalty(self.model, exact=True, **kw) - metrics['valid']['penalty'] = penalty - - formatted = self._format_train(metrics['valid']) - logger.info( - bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) - - # Save the best model - if valid_loss == best_loss or self.args.dset.train_valid: - logger.info(bold('New best valid loss %.4f'), valid_loss) - self.best_state = states.copy_state(state) - self.best_changed = True - - # Eval model every `test.every` epoch or on last epoch - should_eval = (epoch + 1) % self.args.test.every == 0 - is_last = epoch == self.args.epochs - 1 - reco = metrics['valid']['main']['reco'] - # Tries to detect divergence in a reliable way and finish job - # not to waste compute. - div = epoch >= 180 and reco > 0.18 - div = div or epoch >= 100 and reco > 0.25 - div = div and self.args.optim.loss == 'l1' - if div: - logger.warning("Finishing training early because valid loss is too high.") - is_last = True - if should_eval or is_last: - # Evaluate on the testset - logger.info('-' * 70) - logger.info('Evaluating on the test set...') - # We switch to the best known model for testing - if self.args.test.best: - state = self.best_state - else: - state = states.copy_state(self.model.state_dict()) - compute_sdr = self.args.test.sdr and is_last - with states.swap_state(self.model, state): - with torch.no_grad(): - metrics['test'] = evaluate(self, compute_sdr=compute_sdr) - formatted = self._format_test(metrics['test']) - logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) - self.link.push_metrics(metrics) - - if distrib.rank == 0: - # Save model each epoch - self._serialize(epoch) - logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve()) - if is_last: - break - - def _run_one_epoch(self, epoch, train=True): - args = self.args - data_loader = self.loaders['train'] if train else self.loaders['valid'] - # get a different order for distributed training, otherwise this will get ignored - data_loader.sampler.epoch = epoch - - label = ["Valid", "Train"][train] - name = label + f" | Epoch {epoch + 1}" - total = len(data_loader) - if args.max_batches: - total = min(total, args.max_batches) - logprog = LogProgress(logger, data_loader, total=total, - updates=self.args.misc.num_prints, name=name) - averager = EMA() - - for idx, sources in enumerate(logprog): - sources = sources.to(self.device) - if train: - sources = self.augment(sources) - mix = sources.sum(dim=1) - else: - mix = sources[:, 0] - sources = sources[:, 1:] - - if not train and self.args.valid_apply: - estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0) - else: - estimate = self.dmodel(mix) - if train and hasattr(self.model, 'transform_target'): - sources = self.model.transform_target(mix, sources) - assert estimate.shape == sources.shape, (estimate.shape, sources.shape) - dims = tuple(range(2, sources.dim())) - - if args.optim.loss == 'l1': - loss = F.l1_loss(estimate, sources, reduction='none') - loss = loss.mean(dims).mean(0) - reco = loss - elif args.optim.loss == 'mse': - loss = F.mse_loss(estimate, sources, reduction='none') - loss = loss.mean(dims) - reco = loss**0.5 - reco = reco.mean(0) - else: - raise ValueError(f"Invalid loss {self.args.loss}") - weights = torch.tensor(args.weights).to(sources) - loss = (loss * weights).sum() / weights.sum() - - ms = 0 - if self.quantizer is not None: - ms = self.quantizer.model_size() - if args.quant.diffq: - loss += args.quant.diffq * ms - - losses = {} - losses['reco'] = (reco * weights).sum() / weights.sum() - losses['ms'] = ms - - if not train: - nsdrs = new_sdr(sources, estimate.detach()).mean(0) - total = 0 - for source, nsdr, w in zip(self.model.sources, nsdrs, weights): - losses[f'nsdr_{source}'] = nsdr - total += w * nsdr - losses['nsdr'] = total / weights.sum() - - if train and args.svd.penalty > 0: - kw = dict(args.svd) - kw.pop('penalty') - penalty = svd_penalty(self.model, **kw) - losses['penalty'] = penalty - loss += args.svd.penalty * penalty - - losses['loss'] = loss - - for k, source in enumerate(self.model.sources): - losses[f'reco_{source}'] = reco[k] - - # optimize model in training mode - if train: - loss.backward() - grad_norm = 0 - grads = [] - for p in self.model.parameters(): - if p.grad is not None: - grad_norm += p.grad.data.norm()**2 - grads.append(p.grad.data) - losses['grad'] = grad_norm ** 0.5 - if args.optim.clip_grad: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - args.optim.clip_grad) - - if self.args.flag == 'uns': - for n, p in self.model.named_parameters(): - if p.grad is None: - print('no grad', n) - self.optimizer.step() - self.optimizer.zero_grad() - for ema in self.emas['batch']: - ema.update() - losses = averager(losses) - logs = self._format_train(losses) - logprog.update(**logs) - # Just in case, clear some memory - del loss, estimate, reco, ms - if args.max_batches == idx: - break - if self.args.debug and train: - break - if self.args.flag == 'debug': - break - if train: - for ema in self.emas['epoch']: - ema.update() - return distrib.average(losses, idx + 1) diff --git a/demucs/spec.py b/demucs/spec.py deleted file mode 100644 index 85e5dc9..0000000 --- a/demucs/spec.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Conveniance wrapper to perform STFT and iSTFT""" - -import torch as th - - -def spectro(x, n_fft=512, hop_length=None, pad=0): - *other, length = x.shape - x = x.reshape(-1, length) - z = th.stft(x, - n_fft * (1 + pad), - hop_length or n_fft // 4, - window=th.hann_window(n_fft).to(x), - win_length=n_fft, - normalized=True, - center=True, - return_complex=True, - pad_mode='reflect') - _, freqs, frame = z.shape - return z.view(*other, freqs, frame) - - -def ispectro(z, hop_length=None, length=None, pad=0): - *other, freqs, frames = z.shape - n_fft = 2 * freqs - 2 - z = z.view(-1, freqs, frames) - win_length = n_fft // (1 + pad) - x = th.istft(z, - n_fft, - hop_length, - window=th.hann_window(win_length).to(z.real), - win_length=win_length, - normalized=True, - length=length, - center=True) - _, length = x.shape - return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py deleted file mode 100644 index db17a18..0000000 --- a/demucs/states.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Utilities to save and load models. -""" -from contextlib import contextmanager - -import functools -import hashlib -import inspect -import io -from pathlib import Path -import warnings - -from omegaconf import OmegaConf -from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state -import torch - - -def get_quantizer(model, args, optimizer=None): - """Return the quantizer given the XP quantization args.""" - quantizer = None - if args.diffq: - quantizer = DiffQuantizer( - model, min_size=args.min_size, group_size=args.group_size) - if optimizer is not None: - quantizer.setup_optimizer(optimizer) - elif args.qat: - quantizer = UniformQuantizer( - model, bits=args.qat, min_size=args.min_size) - return quantizer - - -def load_model(path_or_package, strict=False): - """Load a model from the given serialized model, either given as a dict (already loaded) - or a path to a file on disk.""" - if isinstance(path_or_package, dict): - package = path_or_package - elif isinstance(path_or_package, (str, Path)): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - path = path_or_package - package = torch.load(path, 'cpu') - else: - raise ValueError(f"Invalid type for {path_or_package}.") - - klass = package["klass"] - args = package["args"] - kwargs = package["kwargs"] - - if strict: - model = klass(*args, **kwargs) - else: - sig = inspect.signature(klass) - for key in list(kwargs): - if key not in sig.parameters: - warnings.warn("Dropping inexistant parameter " + key) - del kwargs[key] - model = klass(*args, **kwargs) - - state = package["state"] - - set_state(model, state) - return model - - -def get_state(model, quantizer, half=False): - """Get the state from a model, potentially with quantization applied. - If `half` is True, model are stored as half precision, which shouldn't impact performance - but half the state size.""" - if quantizer is None: - dtype = torch.half if half else None - state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} - else: - state = quantizer.get_quantized_state() - state['__quantized'] = True - return state - - -def set_state(model, state, quantizer=None): - """Set the state on a given model.""" - if state.get('__quantized'): - if quantizer is not None: - quantizer.restore_quantized_state(model, state['quantized']) - else: - restore_quantized_state(model, state) - else: - model.load_state_dict(state) - return state - - -def save_with_checksum(content, path): - """Save the given value on disk, along with a sha256 hash. - Should be used with the output of either `serialize_model` or `get_state`.""" - buf = io.BytesIO() - torch.save(content, buf) - sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] - - path = path.parent / (path.stem + "-" + sig + path.suffix) - path.write_bytes(buf.getvalue()) - - -def serialize_model(model, training_args, quantizer=None, half=True): - args, kwargs = model._init_args_kwargs - klass = model.__class__ - - state = get_state(model, quantizer, half) - return { - 'klass': klass, - 'args': args, - 'kwargs': kwargs, - 'state': state, - 'training_args': OmegaConf.to_container(training_args, resolve=True), - } - - -def copy_state(state): - return {k: v.cpu().clone() for k, v in state.items()} - - -@contextmanager -def swap_state(model, state): - """ - Context manager that swaps the state of a model, e.g: - - # model is in old state - with swap_state(model, new_state): - # model in new state - # model back to old state - """ - old_state = copy_state(model.state_dict()) - model.load_state_dict(state, strict=False) - try: - yield - finally: - model.load_state_dict(old_state) - - -def capture_init(init): - @functools.wraps(init) - def __init__(self, *args, **kwargs): - self._init_args_kwargs = (args, kwargs) - init(self, *args, **kwargs) - - return __init__ diff --git a/demucs/svd.py b/demucs/svd.py deleted file mode 100644 index 96a74e2..0000000 --- a/demucs/svd.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Ways to make the model stronger.""" -import random -import torch - - -def power_iteration(m, niters=1, bs=1): - """This is the power method. batch size is used to try multiple starting point in parallel.""" - assert m.dim() == 2 - assert m.shape[0] == m.shape[1] - dim = m.shape[0] - b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) - - for _ in range(niters): - n = m.mm(b) - norm = n.norm(dim=0, keepdim=True) - b = n / (1e-10 + norm) - - return norm.mean() - - -# We need a shared RNG to make sure all the distributed worker will skip the penalty together, -# as otherwise we wouldn't get any speed up. -penalty_rng = random.Random(1234) - - -def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True, - proba=1, conv_only=False, exact=False, bs=1): - """ - Penalty on the largest singular value for a layer. - Args: - - model: model to penalize - - min_size: minimum size in MB of a layer to penalize. - - dim: projection dimension for the svd_lowrank. Higher is better but slower. - - niters: number of iterations in the algorithm used by svd_lowrank. - - powm: use power method instead of lowrank SVD, my own experience - is that it is both slower and less stable. - - convtr: when True, differentiate between Conv and Transposed Conv. - this is kept for compatibility with older experiments. - - proba: probability to apply the penalty. - - conv_only: only apply to conv and conv transposed, not LSTM - (might not be reliable for other models than Demucs). - - exact: use exact SVD (slow but useful at validation). - - bs: batch_size for power method. - """ - total = 0 - if penalty_rng.random() > proba: - return 0. - - for m in model.modules(): - for name, p in m.named_parameters(recurse=False): - if p.numel() / 2**18 < min_size: - continue - if convtr: - if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): - if p.dim() in [3, 4]: - p = p.transpose(0, 1).contiguous() - if p.dim() == 3: - p = p.view(len(p), -1) - elif p.dim() == 4: - p = p.view(len(p), -1) - elif p.dim() == 1: - continue - elif conv_only: - continue - assert p.dim() == 2, (name, p.shape) - if exact: - estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() - elif powm: - a, b = p.shape - if a < b: - n = p.mm(p.t()) - else: - n = p.t().mm(p) - estimate = power_iteration(n, niters, bs) - else: - estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) - total += estimate - return total / proba diff --git a/demucs/utils.py b/demucs/utils.py deleted file mode 100644 index 3f2afaa..0000000 --- a/demucs/utils.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -from contextlib import contextmanager -import math -import os -import tempfile -import typing as tp - -import torch -from torch.nn import functional as F - - -def unfold(a, kernel_size, stride): - """Given input of size [*OT, T], output Tensor of size [*OT, F, K] - with K the kernel size, by extracting frames with the given stride. - - This will pad the input so that `F = ceil(T / K)`. - - see https://github.com/pytorch/pytorch/issues/60466 - """ - *shape, length = a.shape - n_frames = math.ceil(length / stride) - tgt_length = (n_frames - 1) * stride + kernel_size - a = F.pad(a, (0, tgt_length - length)) - strides = list(a.stride()) - assert strides[-1] == 1, 'data should be contiguous' - strides = strides[:-1] + [stride, 1] - return a.as_strided([*shape, n_frames, kernel_size], strides) - - -def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): - """ - Center trim `tensor` with respect to `reference`, along the last dimension. - `reference` can also be a number, representing the length to trim to. - If the size difference != 0 mod 2, the extra sample is removed on the right side. - """ - ref_size: int - if isinstance(reference, torch.Tensor): - ref_size = reference.size(-1) - else: - ref_size = reference - delta = tensor.size(-1) - ref_size - if delta < 0: - raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") - if delta: - tensor = tensor[..., delta // 2:-(delta - delta // 2)] - return tensor - - -def pull_metric(history: tp.List[dict], name: str): - out = [] - for metrics in history: - metric = metrics - for part in name.split("."): - metric = metric[part] - out.append(metric) - return out - - -def EMA(beta: float = 1): - """ - Exponential Moving Average callback. - Returns a single function that can be called to repeatidly update the EMA - with a dict of metrics. The callback will return - the new averaged dict of metrics. - - Note that for `beta=1`, this is just plain averaging. - """ - fix: tp.Dict[str, float] = defaultdict(float) - total: tp.Dict[str, float] = defaultdict(float) - - def _update(metrics: dict, weight: float = 1) -> dict: - nonlocal total, fix - for key, value in metrics.items(): - total[key] = total[key] * beta + weight * float(value) - fix[key] = fix[key] * beta + weight - return {key: tot / fix[key] for key, tot in total.items()} - return _update - - -def sizeof_fmt(num: float, suffix: str = 'B'): - """ - Given `num` bytes, return human readable size. - Taken from https://stackoverflow.com/a/1094933 - """ - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: - if abs(num) < 1024.0: - return "%3.1f%s%s" % (num, unit, suffix) - num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) - - -@contextmanager -def temp_filenames(count: int, delete=True): - names = [] - try: - for _ in range(count): - names.append(tempfile.NamedTemporaryFile(delete=False).name) - yield names - finally: - if delete: - for name in names: - os.unlink(name) - - -class DummyPoolExecutor: - class DummyResult: - def __init__(self, func, *args, **kwargs): - self.func = func - self.args = args - self.kwargs = kwargs - - def result(self): - return self.func(*self.args, **self.kwargs) - - def __init__(self, workers=0): - pass - - def submit(self, func, *args, **kwargs): - return DummyPoolExecutor.DummyResult(func, *args, **kwargs) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - return diff --git a/demucs/wav.py b/demucs/wav.py deleted file mode 100644 index 1c023a7..0000000 --- a/demucs/wav.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Loading wav based datasets, including MusdbHQ.""" - -from collections import OrderedDict -import hashlib -import math -import json -import os -from pathlib import Path -import tqdm - -import musdb -import julius -import torch as th -from torch import distributed -import torchaudio as ta -from torch.nn import functional as F - -from .audio import convert_audio_channels -from . import distrib - -MIXTURE = "mixture" -EXT = ".wav" - - -def _track_metadata(track, sources, normalize=True, ext=EXT): - track_length = None - track_samplerate = None - mean = 0 - std = 1 - for source in sources + [MIXTURE]: - file = track / f"{source}{ext}" - try: - info = ta.info(str(file)) - except RuntimeError: - print(file) - raise - length = info.num_frames - if track_length is None: - track_length = length - track_samplerate = info.sample_rate - elif track_length != length: - raise ValueError( - f"Invalid length for file {file}: " - f"expecting {track_length} but got {length}.") - elif info.sample_rate != track_samplerate: - raise ValueError( - f"Invalid sample rate for file {file}: " - f"expecting {track_samplerate} but got {info.sample_rate}.") - if source == MIXTURE and normalize: - try: - wav, _ = ta.load(str(file)) - except RuntimeError: - print(file) - raise - wav = wav.mean(0) - mean = wav.mean().item() - std = wav.std().item() - - return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate} - - -def build_metadata(path, sources, normalize=True, ext=EXT): - """ - Build the metadata for `Wavset`. - - Args: - path (str or Path): path to dataset. - sources (list[str]): list of sources to look for. - normalize (bool): if True, loads full track and store normalization - values based on the mixture file. - ext (str): extension of audio files (default is .wav). - """ - - meta = {} - path = Path(path) - pendings = [] - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor(8) as pool: - for root, folders, files in os.walk(path, followlinks=True): - root = Path(root) - if root.name.startswith('.') or folders or root == path: - continue - name = str(root.relative_to(path)) - pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext))) - # meta[name] = _track_metadata(root, sources, normalize, ext) - for name, pending in tqdm.tqdm(pendings, ncols=120): - meta[name] = pending.result() - return meta - - -class Wavset: - def __init__( - self, - root, metadata, sources, - segment=None, shift=None, normalize=True, - samplerate=44100, channels=2, ext=EXT): - """ - Waveset (or mp3 set for that matter). Can be used to train - with arbitrary sources. Each track should be one folder inside of `path`. - The folder should contain files named `{source}.{ext}`. - - Args: - root (Path or str): root folder for the dataset. - metadata (dict): output from `build_metadata`. - sources (list[str]): list of source names. - segment (None or float): segment length in seconds. If `None`, returns entire tracks. - shift (None or float): stride in seconds bewteen samples. - normalize (bool): normalizes input audio, **based on the metadata content**, - i.e. the entire track is normalized, not individual extracts. - samplerate (int): target sample rate. if the file sample rate - is different, it will be resampled on the fly. - channels (int): target nb of channels. if different, will be - changed onthe fly. - ext (str): extension for audio files (default is .wav). - - samplerate and channels are converted on the fly. - """ - self.root = Path(root) - self.metadata = OrderedDict(metadata) - self.segment = segment - self.shift = shift or segment - self.normalize = normalize - self.sources = sources - self.channels = channels - self.samplerate = samplerate - self.ext = ext - self.num_examples = [] - for name, meta in self.metadata.items(): - track_duration = meta['length'] / meta['samplerate'] - if segment is None or track_duration < segment: - examples = 1 - else: - examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) - self.num_examples.append(examples) - - def __len__(self): - return sum(self.num_examples) - - def get_file(self, name, source): - return self.root / name / f"{source}{self.ext}" - - def __getitem__(self, index): - for name, examples in zip(self.metadata, self.num_examples): - if index >= examples: - index -= examples - continue - meta = self.metadata[name] - num_frames = -1 - offset = 0 - if self.segment is not None: - offset = int(meta['samplerate'] * self.shift * index) - num_frames = int(math.ceil(meta['samplerate'] * self.segment)) - wavs = [] - for source in self.sources: - file = self.get_file(name, source) - wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) - wav = convert_audio_channels(wav, self.channels) - wavs.append(wav) - - example = th.stack(wavs) - example = julius.resample_frac(example, meta['samplerate'], self.samplerate) - if self.normalize: - example = (example - meta['mean']) / meta['std'] - if self.segment: - length = int(self.segment * self.samplerate) - example = example[..., :length] - example = F.pad(example, (0, length - example.shape[-1])) - return example - - -def get_wav_datasets(args): - """Extract the wav datasets from the XP arguments.""" - sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8] - metadata_file = Path(args.metadata) / ('wav_' + sig + ".json") - train_path = Path(args.wav) / "train" - valid_path = Path(args.wav) / "valid" - if not metadata_file.is_file() and distrib.rank == 0: - metadata_file.parent.mkdir(exist_ok=True, parents=True) - train = build_metadata(train_path, args.sources) - valid = build_metadata(valid_path, args.sources) - json.dump([train, valid], open(metadata_file, "w")) - if distrib.world_size > 1: - distributed.barrier() - train, valid = json.load(open(metadata_file)) - if args.full_cv: - kw_cv = {} - else: - kw_cv = {'segment': args.segment, 'shift': args.shift} - train_set = Wavset(train_path, train, args.sources, - segment=args.segment, shift=args.shift, - samplerate=args.samplerate, channels=args.channels, - normalize=args.normalize) - valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources), - samplerate=args.samplerate, channels=args.channels, - normalize=args.normalize, **kw_cv) - return train_set, valid_set - - -def _get_musdb_valid(): - # Return musdb valid set. - import yaml - setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml' - setup = yaml.safe_load(open(setup_path, 'r')) - return setup['validation_tracks'] - - -def get_musdb_wav_datasets(args): - """Extract the musdb dataset from the XP arguments.""" - sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8] - metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json") - root = Path(args.musdb) / "train" - if not metadata_file.is_file() and distrib.rank == 0: - metadata_file.parent.mkdir(exist_ok=True, parents=True) - metadata = build_metadata(root, args.sources) - json.dump(metadata, open(metadata_file, "w")) - if distrib.world_size > 1: - distributed.barrier() - metadata = json.load(open(metadata_file)) - - valid_tracks = _get_musdb_valid() - if args.train_valid: - metadata_train = metadata - else: - metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} - metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks} - if args.full_cv: - kw_cv = {} - else: - kw_cv = {'segment': args.segment, 'shift': args.shift} - train_set = Wavset(root, metadata_train, args.sources, - segment=args.segment, shift=args.shift, - samplerate=args.samplerate, channels=args.channels, - normalize=args.normalize) - valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources), - samplerate=args.samplerate, channels=args.channels, - normalize=args.normalize, **kw_cv) - return train_set, valid_set diff --git a/demucs/wdemucs.py b/demucs/wdemucs.py deleted file mode 100644 index b0d799e..0000000 --- a/demucs/wdemucs.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# For compat -from .hdemucs import HDemucs - -WDemucs = HDemucs