From 6a1b9891193e4e2349dcb70ddd80a7fe626340e7 Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Wed, 2 Nov 2022 03:24:31 -0500 Subject: [PATCH] Delete models.py --- models.py | 129 ------------------------------------------------------ 1 file changed, 129 deletions(-) delete mode 100644 models.py diff --git a/models.py b/models.py deleted file mode 100644 index 77f6a02..0000000 --- a/models.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -from torch._C import has_mkl -import torch.nn as nn -import numpy as np -import librosa - -dim_c = 4 -model_path = 'model' - -class Conv_TDF_net_trim(nn.Module): - def __init__(self, device, n_fft_scale, dim_f, load, model_name, target_name, - L, dim_t, hop=1024): - - super(Conv_TDF_net_trim, self).__init__() - - self.dim_f, self.dim_t = dim_f, 2**dim_t - self.n_fft = n_fft_scale - self.hop = hop - self.n_bins = self.n_fft//2+1 - self.chunk_size = hop * (self.dim_t-1) - self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(device) - self.target_name = target_name - #print(n_fft_scale) - out_c = dim_c*4 if target_name=='*' else dim_c - self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(device) - - def stft(self, x): - x = x.reshape([-1, self.chunk_size]) - x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) - x = x.permute([0,3,1,2]) - x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,dim_c,self.n_bins,self.dim_t]) - return x[:,:,:self.dim_f] - - def istft(self, x, freq_pad=None): - freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad - x = torch.cat([x, freq_pad], -2) - c = 4*2 if self.target_name=='*' else 2 - x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) - x = x.permute([0,2,3,1]) - x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) - return x.reshape([-1,c,self.chunk_size]) - -def stft(wave, nfft, hl): - wave_left = np.asfortranarray(wave[0]) - wave_right = np.asfortranarray(wave[1]) - spec_left = librosa.stft(wave_left, nfft, hop_length=hl) - spec_right = librosa.stft(wave_right, nfft, hop_length=hl) - spec = np.asfortranarray([spec_left, spec_right]) - - return spec - -def istft(spec, hl): - spec_left = np.asfortranarray(spec[0]) - spec_right = np.asfortranarray(spec[1]) - wave_left = librosa.istft(spec_left, hop_length=hl) - wave_right = librosa.istft(spec_right, hop_length=hl) - wave = np.asfortranarray([wave_left, wave_right]) - - return wave - -def spec_effects(wave, algorithm='Default', value=None): - doubleout = spec = [stft(wave[0],2048,1024),stft(wave[1],2048,1024)] - if algorithm == 'Min_Mag': - doubleout - v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]) - wave = istft(v_spec_m,1024) - elif algorithm == 'Max_Mag': - doubleout - v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]) - wave = istft(v_spec_m,1024) - elif algorithm == 'Default': - doubleout - #wave = [istft(spec[0],1024),istft(spec[1],1024)] - wave = (wave[1] * value) + (wave[0] * (1-value)) - elif algorithm == 'Invert_p': - doubleout - X_mag = np.abs(spec[0]) - y_mag = np.abs(spec[1]) - max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) - v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0])) - wave = istft(v_spec,1024) - return wave - - -def get_models(name, device, n_fft_scale, dim_f, load=True, stems='bdov'): - - if name=='tdf_extra': - models = [] - if 'b' in stems: - models.append( - Conv_TDF_net_trim( - device=device, load=load, n_fft_scale=n_fft_scale, - model_name='Conv-TDF', target_name='bass', - L=11, dim_f=dim_f, dim_t=8 - ) - ) - if 'd' in stems: - models.append( - Conv_TDF_net_trim( - device=device, load=load, n_fft_scale=n_fft_scale, - model_name='Conv-TDF', target_name='drums', - L=9, dim_f=dim_f, dim_t=7 - ) - ) - if 'o' in stems: - models.append( - Conv_TDF_net_trim( - device=device, load=load, n_fft_scale=n_fft_scale, - model_name='Conv-TDF', target_name='other', - L=11, dim_f=dim_f, dim_t=8 - ) - ) - if 'v' in stems: - models.append( - Conv_TDF_net_trim( - device=device, load=load, n_fft_scale=n_fft_scale, - model_name='Conv-TDF', target_name='vocals', - L=11, dim_f=dim_f, dim_t=8 - ) - ) - - return models - - else: - print('Model undefined') - return None - - -