diff --git a/lib_v5/vr_network/nets.py b/lib_v5/vr_network/nets.py
index 7e53b6f..3896fce 100644
--- a/lib_v5/vr_network/nets.py
+++ b/lib_v5/vr_network/nets.py
@@ -118,7 +118,7 @@ class CascadedASPPNet(nn.Module):
 
         self.offset = 128
 
-    def forward(self, x, aggressiveness=None):
+    def forward(self, x):
         mix = x.detach()
         x = x.clone()
 
@@ -155,17 +155,12 @@ class CascadedASPPNet(nn.Module):
                 mode='replicate')
             return mask * mix, aux1 * mix, aux2 * mix
         else:
-            if aggressiveness:
-                mask[:, :, :aggressiveness['split_bin']] = torch.pow(mask[:, :, :aggressiveness['split_bin']], 1 + aggressiveness['value'] / 3)
-                mask[:, :, aggressiveness['split_bin']:] = torch.pow(mask[:, :, aggressiveness['split_bin']:], 1 + aggressiveness['value'])
+            return mask# * mix
 
-            return mask * mix
-
-    def predict(self, x_mag, aggressiveness=None):
-        h = self.forward(x_mag, aggressiveness)
+    def predict_mask(self, x):
+        mask = self.forward(x)
 
         if self.offset > 0:
-            h = h[:, :, :, self.offset:-self.offset]
-            assert h.size()[3] > 0
+            mask = mask[:, :, :, self.offset:-self.offset]
 
-        return h
+        return mask
\ No newline at end of file
diff --git a/lib_v5/vr_network/nets_new.py b/lib_v5/vr_network/nets_new.py
index 1629f8a..db8260a 100644
--- a/lib_v5/vr_network/nets_new.py
+++ b/lib_v5/vr_network/nets_new.py
@@ -40,50 +40,32 @@ class BaseNet(nn.Module):
 
 class CascadedNet(nn.Module):
 
-    def __init__(self, n_fft, nn_architecture):
+    def __init__(self, n_fft, nn_arch_size, nout=32, nout_lstm=128):
         super(CascadedNet, self).__init__()
+
         self.max_bin = n_fft // 2
         self.output_bin = n_fft // 2 + 1
         self.nin_lstm = self.max_bin // 2
         self.offset = 64
-        self.nn_architecture = nn_architecture
+        nout = 64 if nn_arch_size == 218409 else nout
 
-        print('ARC SIZE: ', nn_architecture)
-
-        if nn_architecture == 218409:
-            self.stg1_low_band_net = nn.Sequential(
-                BaseNet(2, 32, self.nin_lstm // 2, 128),
-                layers.Conv2DBNActiv(32, 16, 1, 1, 0)
+        self.stg1_low_band_net = nn.Sequential(
+            BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
+            layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
             )
-            self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64)
+
+        self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
 
-            self.stg2_low_band_net = nn.Sequential(
-                BaseNet(18, 64, self.nin_lstm // 2, 128),
-                layers.Conv2DBNActiv(64, 32, 1, 1, 0)
+        self.stg2_low_band_net = nn.Sequential(
+            BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
+            layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
             )
-            self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64)
+        self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
 
-            self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128)
+        self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
 
-            self.out = nn.Conv2d(64, 2, 1, bias=False)
-            self.aux_out = nn.Conv2d(48, 2, 1, bias=False)
-        else:
-            self.stg1_low_band_net = nn.Sequential(
-                BaseNet(2, 16, self.nin_lstm // 2, 128),
-                layers.Conv2DBNActiv(16, 8, 1, 1, 0)
-            )
-            self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)
-
-            self.stg2_low_band_net = nn.Sequential(
-                BaseNet(10, 32, self.nin_lstm // 2, 128),
-                layers.Conv2DBNActiv(32, 16, 1, 1, 0)
-            )
-            self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)
-
-            self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
-
-            self.out = nn.Conv2d(32, 2, 1, bias=False)
-            self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
+        self.out = nn.Conv2d(nout, 2, 1, bias=False)
+        self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
 
     def forward(self, x):
         x = x[:, :, :self.max_bin]
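
Below is a minimal, hedged usage sketch of the refactored interfaces; it is not part of the diff. The module path lib_v5.vr_network comes from the file headers above, and the CascadedNet constructor arguments come from its new signature; the n_fft value, tensor shapes, and the CascadedASPPNet constructor call are illustrative assumptions only.

import torch

from lib_v5.vr_network import nets, nets_new

# nets.CascadedASPPNet: forward() no longer applies the `aggressiveness` dict,
# and predict() has become predict_mask(), which returns the offset-cropped mask
# rather than the masked magnitude spectrogram, so the caller multiplies by the
# mix (and applies any aggressiveness scaling) itself.
vr_model = nets.CascadedASPPNet(2048)        # constructor args assumed, may differ
vr_model.eval()                              # mask branch is the non-training path
x_mag = torch.rand(1, 2, 1025, 512)          # (batch, ch, bins, frames) assumed
with torch.no_grad():
    mask = vr_model.predict_mask(x_mag)
    separated = mask * x_mag[:, :, :, vr_model.offset:-vr_model.offset]

# nets_new.CascadedNet: the per-architecture if/else is replaced by nout/nout_lstm
# constructor arguments; nn_arch_size == 218409 still forces nout = 64.
new_model = nets_new.CascadedNet(n_fft=2048, nn_arch_size=218409, nout=32, nout_lstm=128)

The practical effect of both changes is that aggressiveness post-processing and the per-architecture channel widths now live with the caller and its configuration rather than inside the network modules.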