Add files via upload
This commit is contained in:
218
demucs/model_v2.py
Normal file
218
demucs/model_v2.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import julius
|
||||
from torch import nn
|
||||
from .tasnet_v2 import ConvTasNet
|
||||
|
||||
from .utils import capture_init, center_trim
|
||||
|
||||
|
||||
class BLSTM(nn.Module):
|
||||
def __init__(self, dim, layers=1):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
||||
self.linear = nn.Linear(2 * dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(2, 0, 1)
|
||||
x = self.lstm(x)[0]
|
||||
x = self.linear(x)
|
||||
x = x.permute(1, 2, 0)
|
||||
return x
|
||||
|
||||
|
||||
def rescale_conv(conv, reference):
|
||||
std = conv.weight.std().detach()
|
||||
scale = (std / reference)**0.5
|
||||
conv.weight.data /= scale
|
||||
if conv.bias is not None:
|
||||
conv.bias.data /= scale
|
||||
|
||||
|
||||
def rescale_module(module, reference):
|
||||
for sub in module.modules():
|
||||
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
||||
rescale_conv(sub, reference)
|
||||
|
||||
def auto_load_demucs_model_v2(sources, demucs_model_name):
|
||||
|
||||
if '48' in demucs_model_name:
|
||||
channels=48
|
||||
elif 'unittest' in demucs_model_name:
|
||||
channels=4
|
||||
else:
|
||||
channels=64
|
||||
|
||||
if 'tasnet' in demucs_model_name:
|
||||
init_demucs_model = ConvTasNet(sources, X=10)
|
||||
else:
|
||||
init_demucs_model = Demucs(sources, channels=channels)
|
||||
|
||||
return init_demucs_model
|
||||
|
||||
class Demucs(nn.Module):
|
||||
@capture_init
|
||||
def __init__(self,
|
||||
sources,
|
||||
audio_channels=2,
|
||||
channels=64,
|
||||
depth=6,
|
||||
rewrite=True,
|
||||
glu=True,
|
||||
rescale=0.1,
|
||||
resample=True,
|
||||
kernel_size=8,
|
||||
stride=4,
|
||||
growth=2.,
|
||||
lstm_layers=2,
|
||||
context=3,
|
||||
normalize=False,
|
||||
samplerate=44100,
|
||||
segment_length=4 * 10 * 44100):
|
||||
"""
|
||||
Args:
|
||||
sources (list[str]): list of source names
|
||||
audio_channels (int): stereo or mono
|
||||
channels (int): first convolution channels
|
||||
depth (int): number of encoder/decoder layers
|
||||
rewrite (bool): add 1x1 convolution to each encoder layer
|
||||
and a convolution to each decoder layer.
|
||||
For the decoder layer, `context` gives the kernel size.
|
||||
glu (bool): use glu instead of ReLU
|
||||
resample_input (bool): upsample x2 the input and downsample /2 the output.
|
||||
rescale (int): rescale initial weights of convolutions
|
||||
to get their standard deviation closer to `rescale`
|
||||
kernel_size (int): kernel size for convolutions
|
||||
stride (int): stride for convolutions
|
||||
growth (float): multiply (resp divide) number of channels by that
|
||||
for each layer of the encoder (resp decoder)
|
||||
lstm_layers (int): number of lstm layers, 0 = no lstm
|
||||
context (int): kernel size of the convolution in the
|
||||
decoder before the transposed convolution. If > 1,
|
||||
will provide some context from neighboring time
|
||||
steps.
|
||||
samplerate (int): stored as meta information for easing
|
||||
future evaluations of the model.
|
||||
segment_length (int): stored as meta information for easing
|
||||
future evaluations of the model. Length of the segments on which
|
||||
the model was trained.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.audio_channels = audio_channels
|
||||
self.sources = sources
|
||||
self.kernel_size = kernel_size
|
||||
self.context = context
|
||||
self.stride = stride
|
||||
self.depth = depth
|
||||
self.resample = resample
|
||||
self.channels = channels
|
||||
self.normalize = normalize
|
||||
self.samplerate = samplerate
|
||||
self.segment_length = segment_length
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
if glu:
|
||||
activation = nn.GLU(dim=1)
|
||||
ch_scale = 2
|
||||
else:
|
||||
activation = nn.ReLU()
|
||||
ch_scale = 1
|
||||
in_channels = audio_channels
|
||||
for index in range(depth):
|
||||
encode = []
|
||||
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
|
||||
if rewrite:
|
||||
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
|
||||
self.encoder.append(nn.Sequential(*encode))
|
||||
|
||||
decode = []
|
||||
if index > 0:
|
||||
out_channels = in_channels
|
||||
else:
|
||||
out_channels = len(self.sources) * audio_channels
|
||||
if rewrite:
|
||||
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
|
||||
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
|
||||
if index > 0:
|
||||
decode.append(nn.ReLU())
|
||||
self.decoder.insert(0, nn.Sequential(*decode))
|
||||
in_channels = channels
|
||||
channels = int(growth * channels)
|
||||
|
||||
channels = in_channels
|
||||
|
||||
if lstm_layers:
|
||||
self.lstm = BLSTM(channels, lstm_layers)
|
||||
else:
|
||||
self.lstm = None
|
||||
|
||||
if rescale:
|
||||
rescale_module(self, reference=rescale)
|
||||
|
||||
def valid_length(self, length):
|
||||
"""
|
||||
Return the nearest valid length to use with the model so that
|
||||
there is no time steps left over in a convolutions, e.g. for all
|
||||
layers, size of the input - kernel_size % stride = 0.
|
||||
|
||||
If the mixture has a valid length, the estimated sources
|
||||
will have exactly the same length when context = 1. If context > 1,
|
||||
the two signals can be center trimmed to match.
|
||||
|
||||
For training, extracts should have a valid length.For evaluation
|
||||
on full tracks we recommend passing `pad = True` to :method:`forward`.
|
||||
"""
|
||||
if self.resample:
|
||||
length *= 2
|
||||
for _ in range(self.depth):
|
||||
length = math.ceil((length - self.kernel_size) / self.stride) + 1
|
||||
length = max(1, length)
|
||||
length += self.context - 1
|
||||
for _ in range(self.depth):
|
||||
length = (length - 1) * self.stride + self.kernel_size
|
||||
|
||||
if self.resample:
|
||||
length = math.ceil(length / 2)
|
||||
return int(length)
|
||||
|
||||
def forward(self, mix):
|
||||
x = mix
|
||||
|
||||
if self.normalize:
|
||||
mono = mix.mean(dim=1, keepdim=True)
|
||||
mean = mono.mean(dim=-1, keepdim=True)
|
||||
std = mono.std(dim=-1, keepdim=True)
|
||||
else:
|
||||
mean = 0
|
||||
std = 1
|
||||
|
||||
x = (x - mean) / (1e-5 + std)
|
||||
|
||||
if self.resample:
|
||||
x = julius.resample_frac(x, 1, 2)
|
||||
|
||||
saved = []
|
||||
for encode in self.encoder:
|
||||
x = encode(x)
|
||||
saved.append(x)
|
||||
if self.lstm:
|
||||
x = self.lstm(x)
|
||||
for decode in self.decoder:
|
||||
skip = center_trim(saved.pop(-1), x)
|
||||
x = x + skip
|
||||
x = decode(x)
|
||||
|
||||
if self.resample:
|
||||
x = julius.resample_frac(x, 2, 1)
|
||||
x = x * std + mean
|
||||
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
|
||||
return x
|
||||
Reference in New Issue
Block a user