From 0c9628d728e06e633e96dbf94bd79c6ecda30976 Mon Sep 17 00:00:00 2001
From: Anjok07 <68268275+Anjok07@users.noreply.github.com>
Date: Fri, 31 Mar 2023 05:16:36 -0500
Subject: [PATCH] Add files via upload

---
 lib_v5/mdxnet.py  | 140 ++++++++++++++++++++++++++++++++++++++++++++++
 lib_v5/mixer.ckpt | Bin 0 -> 1208 bytes
 lib_v5/modules.py |  74 ++++++++++++++++++++++++
 3 files changed, 214 insertions(+)
 create mode 100644 lib_v5/mdxnet.py
 create mode 100644 lib_v5/mixer.ckpt
 create mode 100644 lib_v5/modules.py

diff --git a/lib_v5/mdxnet.py b/lib_v5/mdxnet.py
new file mode 100644
index 0000000..c0a61fe
--- /dev/null
+++ b/lib_v5/mdxnet.py
@@ -0,0 +1,140 @@
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+from pytorch_lightning import LightningModule
+from .modules import TFC_TDF
+
+dim_s = 4
+
+class AbstractMDXNet(LightningModule):
+    __metaclass__ = ABCMeta
+
+    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
+        super().__init__()
+        self.target_name = target_name
+        self.lr = lr
+        self.optimizer = optimizer
+        self.dim_c = dim_c
+        self.dim_f = dim_f
+        self.dim_t = dim_t
+        self.n_fft = n_fft
+        self.n_bins = n_fft // 2 + 1
+        self.hop_length = hop_length
+        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
+        self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
+
+    def configure_optimizers(self):
+        if self.optimizer == 'rmsprop':
+            return torch.optim.RMSprop(self.parameters(), self.lr)
+
+        if self.optimizer == 'adamw':
+            return torch.optim.AdamW(self.parameters(), self.lr)
+
+class ConvTDFNet(AbstractMDXNet):
+    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
+                 num_blocks, l, g, k, bn, bias, overlap):
+
+        super(ConvTDFNet, self).__init__(
+            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
+        self.save_hyperparameters()
+
+        self.num_blocks = num_blocks
+        self.l = l
+        self.g = g
+        self.k = k
+        self.bn = bn
+        self.bias = bias
+
+        if optimizer == 'rmsprop':
+            norm = nn.BatchNorm2d
+
+        if optimizer == 'adamw':
+            norm = lambda input:nn.GroupNorm(2, input)
+
+        self.n = num_blocks // 2
+        scale = (2, 2)
+
+        self.first_conv = nn.Sequential(
+            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
+            norm(g),
+            nn.ReLU(),
+        )
+
+        f = self.dim_f
+        c = g
+        self.encoding_blocks = nn.ModuleList()
+        self.ds = nn.ModuleList()
+        for i in range(self.n):
+            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
+            self.ds.append(
+                nn.Sequential(
+                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
+                    norm(c + g),
+                    nn.ReLU()
+                )
+            )
+            f = f // 2
+            c += g
+
+        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)
+
+        self.decoding_blocks = nn.ModuleList()
+        self.us = nn.ModuleList()
+        for i in range(self.n):
+            self.us.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
+                    norm(c - g),
+                    nn.ReLU()
+                )
+            )
+            f = f * 2
+            c -= g
+
+            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
+
+        self.final_conv = nn.Sequential(
+            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
+        )
+
+    def forward(self, x):
+
+        x = self.first_conv(x)
+
+        x = x.transpose(-1, -2)
+
+        ds_outputs = []
+        for i in range(self.n):
+            x = self.encoding_blocks[i](x)
+            ds_outputs.append(x)
+            x = self.ds[i](x)
+
+        x = self.bottleneck_block(x)
+
+        for i in range(self.n):
+            x = self.us[i](x)
+            x *= ds_outputs[-i - 1]
+            x = self.decoding_blocks[i](x)
+
+        x = x.transpose(-1, -2)
+
+        x = self.final_conv(x)
+
+        return x
+
+class Mixer(nn.Module):
+    def __init__(self, device, mixer_path):
+
+        super(Mixer, self).__init__()
+
+        self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False)
+
+        self.load_state_dict(
+            torch.load(mixer_path, map_location=device)
+        )
+
+    def forward(self, x):
+        x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2)
+        x = self.linear(x)
+        return x.transpose(-1,-2).reshape(dim_s,2,-1)
\ No newline at end of file
diff --git a/lib_v5/mixer.ckpt b/lib_v5/mixer.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..986cc4df50332460d78fda0436e8427a1728bf0e
GIT binary patch
literal 1208
[base85-encoded binary payload (1208-byte mixer weights) omitted]

literal 0
HcmV?d00001

diff --git a/lib_v5/modules.py b/lib_v5/modules.py
new file mode 100644
index 0000000..4e77d2f
--- /dev/null
+++ b/lib_v5/modules.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+
+class TFC(nn.Module):
+    def __init__(self, c, l, k, norm):
+        super(TFC, self).__init__()
+
+        self.H = nn.ModuleList()
+        for i in range(l):
+            self.H.append(
+                nn.Sequential(
+                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
+                    norm(c),
+                    nn.ReLU(),
+                )
+            )
+
+    def forward(self, x):
+        for h in self.H:
+            x = h(x)
+        return x
+
+
+class DenseTFC(nn.Module):
+    def __init__(self, c, l, k, norm):
+        super(DenseTFC, self).__init__()
+
+        self.conv = nn.ModuleList()
+        for i in range(l):
+            self.conv.append(
+                nn.Sequential(
+                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
+                    norm(c),
+                    nn.ReLU(),
+                )
+            )
+
+    def forward(self, x):
+        for layer in self.conv[:-1]:
+            x = torch.cat([layer(x), x], 1)
+        return self.conv[-1](x)
+
+
+class TFC_TDF(nn.Module):
+    def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):
+
+        super(TFC_TDF, self).__init__()
+
+        self.use_tdf = bn is not None
+
+        self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)
+
+        if self.use_tdf:
+            if bn == 0:
+                self.tdf = nn.Sequential(
+                    nn.Linear(f, f, bias=bias),
+                    norm(c),
+                    nn.ReLU()
+                )
+            else:
+                self.tdf = nn.Sequential(
+                    nn.Linear(f, f // bn, bias=bias),
+                    norm(c),
+                    nn.ReLU(),
+                    nn.Linear(f // bn, f, bias=bias),
+                    norm(c),
+                    nn.ReLU()
+                )
+
+    def forward(self, x):
+        x = self.tfc(x)
+        return x + self.tdf(x) if self.use_tdf else x
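
As a quick sanity check of the patch above, here is a minimal sketch (not part of the commit) showing how the TFC_TDF block from lib_v5/modules.py might be exercised on its own. The hyperparameter values and the import path are illustrative assumptions: they presume lib_v5 is importable as a package and are not values prescribed by this patch.

    import torch
    from lib_v5.modules import TFC_TDF  # assumption: lib_v5 is on the Python path as a package

    # Illustrative hyperparameters: channels, conv layers, frequency bins, kernel size, bottleneck factor.
    c, l, f, k, bn = 32, 3, 256, 3, 8
    block = TFC_TDF(c, l, f, k, bn)

    # Input laid out as (batch, channels, time, freq); the TDF branch's Linear layers act on the last (freq) axis.
    x = torch.randn(2, c, 64, f)
    y = block(x)
    print(y.shape)  # torch.Size([2, 32, 64, 256]) -- the residual TFC+TDF block preserves the input shape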