Add files via upload
This commit is contained in:
@@ -40,50 +40,32 @@ class BaseNet(nn.Module):
|
||||
|
||||
class CascadedNet(nn.Module):
|
||||
|
||||
def __init__(self, n_fft, nn_architecture):
|
||||
def __init__(self, n_fft, nn_arch_size, nout=32, nout_lstm=128):
|
||||
super(CascadedNet, self).__init__()
|
||||
|
||||
self.max_bin = n_fft // 2
|
||||
self.output_bin = n_fft // 2 + 1
|
||||
self.nin_lstm = self.max_bin // 2
|
||||
self.offset = 64
|
||||
self.nn_architecture = nn_architecture
|
||||
nout = 64 if nn_arch_size == 218409 else nout
|
||||
|
||||
print('ARC SIZE: ', nn_architecture)
|
||||
|
||||
if nn_architecture == 218409:
|
||||
self.stg1_low_band_net = nn.Sequential(
|
||||
BaseNet(2, 32, self.nin_lstm // 2, 128),
|
||||
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
|
||||
self.stg1_low_band_net = nn.Sequential(
|
||||
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
|
||||
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
|
||||
)
|
||||
self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64)
|
||||
|
||||
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
|
||||
|
||||
self.stg2_low_band_net = nn.Sequential(
|
||||
BaseNet(18, 64, self.nin_lstm // 2, 128),
|
||||
layers.Conv2DBNActiv(64, 32, 1, 1, 0)
|
||||
self.stg2_low_band_net = nn.Sequential(
|
||||
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
|
||||
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
|
||||
)
|
||||
self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64)
|
||||
self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
|
||||
|
||||
self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128)
|
||||
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
|
||||
|
||||
self.out = nn.Conv2d(64, 2, 1, bias=False)
|
||||
self.aux_out = nn.Conv2d(48, 2, 1, bias=False)
|
||||
else:
|
||||
self.stg1_low_band_net = nn.Sequential(
|
||||
BaseNet(2, 16, self.nin_lstm // 2, 128),
|
||||
layers.Conv2DBNActiv(16, 8, 1, 1, 0)
|
||||
)
|
||||
self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)
|
||||
|
||||
self.stg2_low_band_net = nn.Sequential(
|
||||
BaseNet(10, 32, self.nin_lstm // 2, 128),
|
||||
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
|
||||
)
|
||||
self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)
|
||||
|
||||
self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
|
||||
|
||||
self.out = nn.Conv2d(32, 2, 1, bias=False)
|
||||
self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
|
||||
self.out = nn.Conv2d(nout, 2, 1, bias=False)
|
||||
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = x[:, :, :self.max_bin]
|
||||
|
||||
Reference in New Issue
Block a user