Add files via upload

This commit is contained in:
Anjok07
2023-04-12 02:13:30 -05:00
committed by GitHub
parent 18d32660db
commit 6ffd7a244e
2 changed files with 21 additions and 44 deletions

View File

@@ -40,50 +40,32 @@ class BaseNet(nn.Module):
class CascadedNet(nn.Module):
def __init__(self, n_fft, nn_architecture):
def __init__(self, n_fft, nn_arch_size, nout=32, nout_lstm=128):
super(CascadedNet, self).__init__()
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64
self.nn_architecture = nn_architecture
nout = 64 if nn_arch_size == 218409 else nout
print('ARC SIZE: ', nn_architecture)
if nn_architecture == 218409:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64)
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
self.stg2_low_band_net = nn.Sequential(
BaseNet(18, 64, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(64, 32, 1, 1, 0)
self.stg2_low_band_net = nn.Sequential(
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64)
self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128)
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
self.out = nn.Conv2d(64, 2, 1, bias=False)
self.aux_out = nn.Conv2d(48, 2, 1, bias=False)
else:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 16, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(16, 8, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)
self.stg2_low_band_net = nn.Sequential(
BaseNet(10, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)
self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
self.out = nn.Conv2d(32, 2, 1, bias=False)
self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
self.out = nn.Conv2d(nout, 2, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
def forward(self, x):
x = x[:, :, :self.max_bin]