From 535225172e20ce94350d0d04ddb202e0bc820d06 Mon Sep 17 00:00:00 2001 From: 233lol <18070600+233lol@users.noreply.github.com> Date: Thu, 6 Apr 2023 09:55:10 +0800 Subject: [PATCH] fix stft and istft in pyotrch 2.0.0 fix stft and istft in pyotrch 2.0.0 in pytorch 2.0.0 not support real output(stft)and real input(istft) --- separate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/separate.py b/separate.py index 23e8a12..c1fcabc 100644 --- a/separate.py +++ b/separate.py @@ -392,7 +392,8 @@ class SeperateMDX(SeperateAttributes): def stft(self, x): x = x.reshape([-1, self.chunk_size]) - x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True) + x=torch.view_as_real(x) x = x.permute([0,3,1,2]) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t]) return x[:,:,:self.dim_f] @@ -402,6 +403,8 @@ class SeperateMDX(SeperateAttributes): x = torch.cat([x, freq_pad], -2) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) x = x.permute([0,2,3,1]) + x=x.contiguous() + x=torch.view_as_complex(x) x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) return x.reshape([-1,2,self.chunk_size]) @@ -936,4 +939,4 @@ def save_format(audio_path, save_format, mp3_bit_set): try: os.remove(audio_path) except Exception as e: - print(e) \ No newline at end of file + print(e)