Add files via upload
This commit is contained in:
341
separate.py
341
separate.py
@@ -12,6 +12,7 @@ from lib_v5.vr_network import nets_new
|
||||
#from lib_v5.vr_network.model_param_init import ModelParameters
|
||||
from pathlib import Path
|
||||
from gui_data.constants import *
|
||||
from gui_data.error_handling import *
|
||||
import audioread
|
||||
import gzip
|
||||
import librosa
|
||||
@@ -23,6 +24,8 @@ import torch
|
||||
import warnings
|
||||
import pydub
|
||||
import soundfile as sf
|
||||
import traceback
|
||||
import lib_v5.mdxnet as MdxnetSet
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from UVR import ModelData
|
||||
@@ -46,7 +49,10 @@ class SeperateAttributes:
|
||||
self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
|
||||
self.list_all_models = process_data['list_all_models']
|
||||
self.process_iteration = process_data['process_iteration']
|
||||
self.mixer_path = model_data.mixer_path
|
||||
self.model_samplerate = model_data.model_samplerate
|
||||
self.model_capacity = model_data.model_capacity
|
||||
self.is_vr_51_model = model_data.is_vr_51_model
|
||||
self.is_pre_proc_model = model_data.is_pre_proc_model
|
||||
self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False
|
||||
self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True
|
||||
@@ -67,6 +73,7 @@ class SeperateAttributes:
|
||||
self.primary_stem = model_data.primary_stem #
|
||||
self.secondary_stem = model_data.secondary_stem #
|
||||
self.is_invert_spec = model_data.is_invert_spec #
|
||||
self.is_mixer_mode = model_data.is_mixer_mode #
|
||||
self.secondary_model_scale = model_data.secondary_model_scale #
|
||||
self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
|
||||
self.primary_source_map = {}
|
||||
@@ -94,21 +101,24 @@ class SeperateAttributes:
|
||||
self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False
|
||||
|
||||
if model_data.process_method == MDX_ARCH_TYPE:
|
||||
self.is_mdx_ckpt = model_data.is_mdx_ckpt
|
||||
self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
|
||||
self.is_denoise = model_data.is_denoise
|
||||
self.mdx_batch_size = model_data.mdx_batch_size
|
||||
self.compensate = model_data.compensate
|
||||
self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
|
||||
self.n_fft = model_data.mdx_n_fft_scale_set
|
||||
self.chunks = model_data.chunks
|
||||
self.margin = model_data.margin
|
||||
self.hop = 1024
|
||||
self.n_bins = self.n_fft//2+1
|
||||
self.chunk_size = self.hop * (self.dim_t-1)
|
||||
self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(cpu)
|
||||
self.adjust = 1
|
||||
self.dim_c = 4
|
||||
out_c = self.dim_c
|
||||
self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(cpu)
|
||||
|
||||
self.hop = 1024
|
||||
|
||||
if self.is_gpu_conversion >= 0 and torch.cuda.is_available():
|
||||
self.device, self.run_type = torch.device('cuda:0'), ['CUDAExecutionProvider']
|
||||
else:
|
||||
self.device, self.run_type = torch.device('cpu'), ['CPUExecutionProvider']
|
||||
|
||||
if model_data.process_method == DEMUCS_ARCH_TYPE:
|
||||
self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
|
||||
self.secondary_model_4_stem = model_data.secondary_model_4_stem
|
||||
@@ -152,15 +162,14 @@ class SeperateAttributes:
|
||||
self.is_post_process = model_data.is_post_process
|
||||
self.is_gpu_conversion = model_data.is_gpu_conversion
|
||||
self.batch_size = model_data.batch_size
|
||||
self.crop_size = model_data.crop_size
|
||||
self.window_size = model_data.window_size
|
||||
self.input_high_end_h = None
|
||||
self.post_process_threshold = model_data.post_process_threshold
|
||||
self.aggressiveness = {'value': model_data.aggression_setting,
|
||||
'split_bin': self.mp.param['band'][1]['crop_stop'],
|
||||
'aggr_correction': self.mp.param.get('aggr_correction')}
|
||||
|
||||
def start_inference(self):
|
||||
|
||||
def start_inference_console_write(self):
|
||||
|
||||
if self.is_secondary_model and not self.is_pre_proc_model:
|
||||
self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))
|
||||
@@ -168,7 +177,7 @@ class SeperateAttributes:
|
||||
if self.is_pre_proc_model:
|
||||
self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
|
||||
|
||||
def running_inference(self, is_no_write=False):
|
||||
def running_inference_console_write(self, is_no_write=False):
|
||||
|
||||
self.write_to_console(DONE, base_text='') if not is_no_write else None
|
||||
self.set_progress_bar(0.05) if not is_no_write else None
|
||||
@@ -180,6 +189,15 @@ class SeperateAttributes:
|
||||
else:
|
||||
self.write_to_console(INFERENCE_STEP_1)
|
||||
|
||||
def running_inference_progress_bar(self, length, is_match_mix=False):
|
||||
if not is_match_mix:
|
||||
self.progress_value += 1
|
||||
|
||||
if (0.8/length*self.progress_value) >= 0.8:
|
||||
length = self.progress_value + 1
|
||||
|
||||
self.set_progress_bar(0.1, (0.8/length*self.progress_value))
|
||||
|
||||
def load_cached_sources(self, is_4_stem_demucs=False):
|
||||
|
||||
if self.is_secondary_model and not self.is_pre_proc_model:
|
||||
@@ -222,31 +240,51 @@ class SeperateAttributes:
|
||||
self.write_to_console(DONE, base_text='')
|
||||
self.set_progress_bar(0.95)
|
||||
|
||||
def run_mixer(self, mix, sources):
|
||||
try:
|
||||
if self.is_mixer_mode and len(sources) == 4:
|
||||
mixer = MdxnetSet.Mixer(self.device, self.mixer_path).eval()
|
||||
with torch.no_grad():
|
||||
mix = torch.tensor(mix, dtype=torch.float32)
|
||||
sources_ = torch.tensor(sources).detach()
|
||||
x = torch.cat([sources_, mix.unsqueeze(0)], 0)
|
||||
sources_ = mixer(x)
|
||||
final_source = np.array(sources_)
|
||||
else:
|
||||
final_source = sources
|
||||
except Exception as e:
|
||||
error_name = f'{type(e).__name__}'
|
||||
traceback_text = ''.join(traceback.format_tb(e.__traceback__))
|
||||
message = f'{error_name}: "{e}"\n{traceback_text}"'
|
||||
print('Mixer Failed: ', message)
|
||||
final_source = sources
|
||||
|
||||
return final_source
|
||||
|
||||
class SeperateMDX(SeperateAttributes):
|
||||
|
||||
def seperate(self):
|
||||
|
||||
samplerate = 44100
|
||||
|
||||
|
||||
if self.primary_model_name == self.model_basename and self.primary_sources:
|
||||
self.primary_source, self.secondary_source = self.load_cached_sources()
|
||||
else:
|
||||
self.start_inference()
|
||||
self.start_inference_console_write()
|
||||
|
||||
if self.is_gpu_conversion >= 0:
|
||||
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
run_type = ['CUDAExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
|
||||
if self.is_mdx_ckpt:
|
||||
model_params = torch.load(self.model_path, map_location=lambda storage, loc: storage)['hyper_parameters']
|
||||
self.dim_c, self.hop = model_params['dim_c'], model_params['hop_length']
|
||||
separator = MdxnetSet.ConvTDFNet(**model_params)
|
||||
self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
|
||||
else:
|
||||
self.device = torch.device('cpu')
|
||||
run_type = ['CPUExecutionProvider']
|
||||
ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
|
||||
self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
|
||||
|
||||
self.onnx_model = ort.InferenceSession(self.model_path, providers=run_type)
|
||||
|
||||
self.running_inference()
|
||||
self.initialize_model_settings()
|
||||
self.running_inference_console_write()
|
||||
mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False
|
||||
mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut)
|
||||
|
||||
source = self.demix_base(mix)
|
||||
source = self.demix_base(mix, is_ckpt=self.is_mdx_ckpt)[0]
|
||||
self.write_to_console(DONE, base_text='')
|
||||
|
||||
if self.is_secondary_model_activated:
|
||||
@@ -257,7 +295,7 @@ class SeperateMDX(SeperateAttributes):
|
||||
self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
|
||||
primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
|
||||
if not isinstance(self.primary_source, np.ndarray):
|
||||
self.primary_source = spec_utils.normalize(source[0], self.is_normalization).T
|
||||
self.primary_source = spec_utils.normalize(source, self.is_normalization).T
|
||||
self.primary_source_map = {self.primary_stem: self.primary_source}
|
||||
self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
|
||||
|
||||
@@ -266,7 +304,7 @@ class SeperateMDX(SeperateAttributes):
|
||||
secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
|
||||
if not isinstance(self.secondary_source, np.ndarray):
|
||||
raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix
|
||||
self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source[0]*self.compensate, raw_mix, self.is_normalization)
|
||||
self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source*self.compensate, raw_mix, self.is_normalization)
|
||||
|
||||
if self.is_invert_spec:
|
||||
self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
|
||||
@@ -277,7 +315,6 @@ class SeperateMDX(SeperateAttributes):
|
||||
self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
|
||||
|
||||
self.cache_source(secondary_sources)
|
||||
@@ -285,53 +322,73 @@ class SeperateMDX(SeperateAttributes):
|
||||
if self.is_secondary_model:
|
||||
return secondary_sources
|
||||
|
||||
def demix_base(self, mix, is_match_mix=False):
|
||||
chunked_sources = []
|
||||
def initialize_model_settings(self):
|
||||
self.n_bins = self.n_fft//2+1
|
||||
self.trim = self.n_fft//2
|
||||
self.chunk_size = self.hop * (self.dim_t-1)
|
||||
self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
|
||||
self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
|
||||
self.gen_size = self.chunk_size-2*self.trim
|
||||
|
||||
for slice in mix:
|
||||
self.progress_value += 1
|
||||
self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if not is_match_mix else None
|
||||
cmix = mix[slice]
|
||||
sources = []
|
||||
def initialize_mix(self, mix, is_ckpt=False):
|
||||
if is_ckpt:
|
||||
pad = self.gen_size + self.trim - ((mix.shape[-1]) % self.gen_size)
|
||||
mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'),mix, np.zeros((2, pad), dtype='float32')), 1)
|
||||
num_chunks = mixture.shape[-1] // self.gen_size
|
||||
mix_waves = [mixture[:, i * self.gen_size: i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
|
||||
else:
|
||||
mix_waves = []
|
||||
n_sample = cmix.shape[1]
|
||||
trim = self.n_fft//2
|
||||
gen_size = self.chunk_size-2*trim
|
||||
pad = gen_size - n_sample%gen_size
|
||||
mix_p = np.concatenate((np.zeros((2,trim)), cmix, np.zeros((2,pad)), np.zeros((2,trim))), 1)
|
||||
n_sample = mix.shape[1]
|
||||
pad = self.gen_size - n_sample%self.gen_size
|
||||
mix_p = np.concatenate((np.zeros((2,self.trim)), mix, np.zeros((2,pad)), np.zeros((2,self.trim))), 1)
|
||||
i = 0
|
||||
while i < n_sample + pad:
|
||||
waves = np.array(mix_p[:, i:i+self.chunk_size])
|
||||
mix_waves.append(waves)
|
||||
i += gen_size
|
||||
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
|
||||
with torch.no_grad():
|
||||
_ort = self.onnx_model if not is_match_mix else None
|
||||
adjust = 1
|
||||
spek = self.stft(mix_waves)*adjust
|
||||
i += self.gen_size
|
||||
|
||||
if not is_match_mix:
|
||||
if self.is_denoise:
|
||||
spec_pred = -_ort.run(None, {'input': -spek.cpu().numpy()})[0]*0.5+_ort.run(None, {'input': spek.cpu().numpy()})[0]*0.5
|
||||
else:
|
||||
spec_pred = _ort.run(None, {'input': spek.cpu().numpy()})[0]
|
||||
else:
|
||||
spec_pred = spek.cpu().numpy()
|
||||
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
|
||||
|
||||
tar_waves = self.istft(torch.tensor(spec_pred))#.cpu()
|
||||
tar_signal = tar_waves[:,:,trim:-trim].transpose(0,1).reshape(2, -1).numpy()[:, :-pad]
|
||||
return mix_waves, pad
|
||||
|
||||
def demix_base(self, mix, is_ckpt=False, is_match_mix=False):
|
||||
chunked_sources = []
|
||||
for slice in mix:
|
||||
sources = []
|
||||
tar_waves_ = []
|
||||
mix_p = mix[slice]
|
||||
mix_waves, pad = self.initialize_mix(mix_p, is_ckpt=is_ckpt)
|
||||
mix_waves = mix_waves.split(self.mdx_batch_size)
|
||||
pad = mix_p.shape[-1] if is_ckpt else -pad
|
||||
with torch.no_grad():
|
||||
for mix_wave in mix_waves:
|
||||
self.running_inference_progress_bar(len(mix)*len(mix_waves), is_match_mix=is_match_mix)
|
||||
tar_waves = self.run_model(mix_wave, is_ckpt=is_ckpt, is_match_mix=is_match_mix)
|
||||
tar_waves_.append(tar_waves)
|
||||
tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim] if is_ckpt else tar_waves_
|
||||
tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :pad]
|
||||
start = 0 if slice == 0 else self.margin
|
||||
end = None if slice == list(mix.keys())[::-1][0] else -self.margin
|
||||
if self.margin == 0:
|
||||
end = None
|
||||
sources.append(tar_signal[:,start:end]*(1/adjust))
|
||||
end = None if slice == list(mix.keys())[::-1][0] or self.margin == 0 else -self.margin
|
||||
sources.append(tar_waves[:,start:end]*(1/self.adjust))
|
||||
chunked_sources.append(sources)
|
||||
sources = np.concatenate(chunked_sources, axis=-1)
|
||||
|
||||
if not is_match_mix:
|
||||
del self.onnx_model
|
||||
|
||||
|
||||
return sources
|
||||
|
||||
def run_model(self, mix, is_ckpt=False, is_match_mix=False):
|
||||
|
||||
spek = self.stft(mix.to(self.device))*self.adjust
|
||||
spek[:, :, :3, :] *= 0
|
||||
|
||||
if is_match_mix:
|
||||
spec_pred = spek.cpu().numpy()
|
||||
else:
|
||||
spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
|
||||
|
||||
if is_ckpt:
|
||||
return self.istft(spec_pred).cpu().detach().numpy()
|
||||
else:
|
||||
return self.istft(torch.tensor(spec_pred).to(self.device)).to(cpu)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1).numpy()
|
||||
|
||||
def stft(self, x):
|
||||
x = x.reshape([-1, self.chunk_size])
|
||||
@@ -343,11 +400,10 @@ class SeperateMDX(SeperateAttributes):
|
||||
def istft(self, x, freq_pad=None):
|
||||
freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
|
||||
x = torch.cat([x, freq_pad], -2)
|
||||
c = 2
|
||||
x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
|
||||
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
|
||||
x = x.permute([0,2,3,1])
|
||||
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
|
||||
return x.reshape([-1,c,self.chunk_size])
|
||||
return x.reshape([-1,2,self.chunk_size])
|
||||
|
||||
class SeperateDemucs(SeperateAttributes):
|
||||
|
||||
@@ -371,7 +427,7 @@ class SeperateDemucs(SeperateAttributes):
|
||||
source = self.primary_sources
|
||||
self.load_cached_sources(is_4_stem_demucs=True)
|
||||
else:
|
||||
self.start_inference()
|
||||
self.start_inference_console_write()
|
||||
|
||||
if self.is_gpu_conversion >= 0:
|
||||
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
@@ -405,22 +461,25 @@ class SeperateDemucs(SeperateAttributes):
|
||||
mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
|
||||
inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs)
|
||||
self.process_iteration()
|
||||
self.running_inference(is_no_write=is_no_write)
|
||||
self.running_inference_console_write(is_no_write=is_no_write)
|
||||
inst_source = self.demix_demucs(inst_mix)
|
||||
inst_source = self.run_mixer(inst_raw_mix, inst_source)
|
||||
self.process_iteration()
|
||||
|
||||
self.running_inference(is_no_write=is_no_write) if not self.pre_proc_model else None
|
||||
self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
|
||||
mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs)
|
||||
|
||||
if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
|
||||
source = self.primary_sources
|
||||
else:
|
||||
source = self.demix_demucs(mix)
|
||||
source = self.run_mixer(raw_mix, source)
|
||||
|
||||
self.write_to_console(DONE, base_text='')
|
||||
|
||||
del self.demucs
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if isinstance(inst_source, np.ndarray):
|
||||
source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
|
||||
inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape
|
||||
@@ -431,6 +490,7 @@ class SeperateDemucs(SeperateAttributes):
|
||||
self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
|
||||
else:
|
||||
self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER
|
||||
|
||||
if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model:
|
||||
is_no_piano_guitar = True
|
||||
six_stem_other_source = list(source)
|
||||
@@ -445,7 +505,6 @@ class SeperateDemucs(SeperateAttributes):
|
||||
self.cache_source(source)
|
||||
|
||||
for stem_name, stem_value in self.demucs_source_map.items():
|
||||
|
||||
if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
|
||||
if self.secondary_model_4_stem[stem_value]:
|
||||
model_scale = self.secondary_model_4_stem_scale[stem_value]
|
||||
@@ -520,9 +579,7 @@ class SeperateDemucs(SeperateAttributes):
|
||||
|
||||
if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
|
||||
secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
|
||||
|
||||
self.cache_source(secondary_sources)
|
||||
@@ -583,11 +640,10 @@ class SeperateDemucs(SeperateAttributes):
|
||||
class SeperateVR(SeperateAttributes):
|
||||
|
||||
def seperate(self):
|
||||
|
||||
if self.primary_model_name == self.model_basename and self.primary_sources:
|
||||
self.primary_source, self.secondary_source = self.load_cached_sources()
|
||||
else:
|
||||
self.start_inference()
|
||||
self.start_inference_console_write()
|
||||
if self.is_gpu_conversion >= 0:
|
||||
if OPERATING_SYSTEM == 'Darwin':
|
||||
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
|
||||
@@ -595,32 +651,27 @@ class SeperateVR(SeperateAttributes):
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
|
||||
nn_arch_sizes = [
|
||||
31191, # default
|
||||
33966, 56817, 218409, 123821, 123812, 129605, 537238, 537227]
|
||||
33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]
|
||||
vr_5_1_models = [56817, 218409]
|
||||
|
||||
model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
|
||||
nn_architecture = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
|
||||
nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
|
||||
|
||||
if nn_architecture in vr_5_1_models:
|
||||
model = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_architecture)
|
||||
inference = self.inference_vr_new
|
||||
if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
|
||||
self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
|
||||
else:
|
||||
model = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_architecture)
|
||||
inference = self.inference_vr
|
||||
self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)
|
||||
|
||||
self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
|
||||
self.model_run.to(device)
|
||||
|
||||
model.load_state_dict(torch.load(self.model_path, map_location=device))
|
||||
model.to(device)
|
||||
|
||||
self.running_inference()
|
||||
|
||||
y_spec, v_spec = inference(self.loading_mix(), device, model, self.aggressiveness)
|
||||
self.running_inference_console_write()
|
||||
|
||||
y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
|
||||
self.write_to_console(DONE, base_text='')
|
||||
|
||||
del model
|
||||
|
||||
if self.is_secondary_model_activated:
|
||||
if self.secondary_model:
|
||||
self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
|
||||
@@ -649,9 +700,8 @@ class SeperateVR(SeperateAttributes):
|
||||
self.secondary_source_map = {self.secondary_stem: self.secondary_source}
|
||||
|
||||
self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary)
|
||||
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
|
||||
self.cache_source(secondary_sources)
|
||||
|
||||
@@ -696,80 +746,20 @@ class SeperateVR(SeperateAttributes):
|
||||
|
||||
return X_spec
|
||||
|
||||
def inference_vr(self, X_spec, device, model, aggressiveness):
|
||||
|
||||
def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness):
|
||||
model.eval()
|
||||
|
||||
total_iterations = sum([n_window]) if not self.is_tta else sum([n_window])*2
|
||||
|
||||
with torch.no_grad():
|
||||
preds = []
|
||||
|
||||
for i in range(n_window):
|
||||
self.progress_value +=1
|
||||
self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
|
||||
start = i * roi_size
|
||||
X_mag_window = X_mag_pad[None, :, :, start:start + self.window_size]
|
||||
X_mag_window = torch.from_numpy(X_mag_window).to(device)
|
||||
pred = model.predict(X_mag_window, aggressiveness)
|
||||
pred = pred.detach().cpu().numpy()
|
||||
preds.append(pred[0])
|
||||
|
||||
pred = np.concatenate(preds, axis=2)
|
||||
return pred
|
||||
|
||||
X_mag, X_phase = spec_utils.preprocess(X_spec)
|
||||
coef = X_mag.max()
|
||||
X_mag_pre = X_mag / coef
|
||||
n_frame = X_mag_pre.shape[2]
|
||||
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, model.offset)
|
||||
n_window = int(np.ceil(n_frame / roi_size))
|
||||
X_mag_pad = np.pad(
|
||||
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||
pred = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
|
||||
pred = pred[:, :, :n_frame]
|
||||
|
||||
if self.is_tta:
|
||||
pad_l += roi_size // 2
|
||||
pad_r += roi_size // 2
|
||||
n_window += 1
|
||||
X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||
pred_tta = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
|
||||
pred_tta = pred_tta[:, :, roi_size // 2:]
|
||||
pred_tta = pred_tta[:, :, :n_frame]
|
||||
pred, X_mag, X_phase = (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
|
||||
else:
|
||||
pred, X_mag, X_phase = pred * coef, X_mag, np.exp(1.j * X_phase)
|
||||
|
||||
if self.is_post_process:
|
||||
pred_inv = np.clip(X_mag - pred, 0, np.inf)
|
||||
pred = spec_utils.mask_silence(pred, pred_inv, thres=self.post_process_threshold)
|
||||
|
||||
y_spec = pred * X_phase
|
||||
v_spec = X_spec - y_spec
|
||||
|
||||
return y_spec, v_spec
|
||||
|
||||
def inference_vr_new(self, X_spec, device, model, aggressiveness):
|
||||
|
||||
def inference_vr(self, X_spec, device, aggressiveness):
|
||||
def _execute(X_mag_pad, roi_size):
|
||||
|
||||
X_dataset = []
|
||||
patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
|
||||
patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
|
||||
total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2
|
||||
|
||||
for i in range(patches):
|
||||
start = i * roi_size
|
||||
X_mag_crop = X_mag_pad[:, :, start:start + self.crop_size]
|
||||
X_dataset.append(X_mag_crop)
|
||||
X_mag_window = X_mag_pad[:, :, start:start + self.window_size]
|
||||
X_dataset.append(X_mag_window)
|
||||
|
||||
X_dataset = np.asarray(X_dataset)
|
||||
model.eval()
|
||||
|
||||
self.model_run.eval()
|
||||
with torch.no_grad():
|
||||
mask = []
|
||||
# To reduce the overhead, dataloader is not used.
|
||||
for i in range(0, patches, self.batch_size):
|
||||
self.progress_value += 1
|
||||
if self.progress_value >= total_iterations:
|
||||
@@ -777,33 +767,37 @@ class SeperateVR(SeperateAttributes):
|
||||
self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
|
||||
X_batch = X_dataset[i: i + self.batch_size]
|
||||
X_batch = torch.from_numpy(X_batch).to(device)
|
||||
pred = model.predict_mask(X_batch)
|
||||
pred = self.model_run.predict_mask(X_batch)
|
||||
if not pred.size()[3] > 0:
|
||||
raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
|
||||
pred = pred.detach().cpu().numpy()
|
||||
pred = np.concatenate(pred, axis=2)
|
||||
mask.append(pred)
|
||||
|
||||
if len(mask) == 0:
|
||||
raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
|
||||
|
||||
mask = np.concatenate(mask, axis=2)
|
||||
|
||||
return mask
|
||||
|
||||
def postprocess(mask, X_mag, X_phase, aggressiveness):
|
||||
|
||||
def postprocess(mask, X_mag, X_phase):
|
||||
|
||||
if self.primary_stem == VOCAL_STEM:
|
||||
mask = (1.0 - spec_utils.adjust_aggr(mask, True, aggressiveness))
|
||||
else:
|
||||
mask = spec_utils.adjust_aggr(mask, False, aggressiveness)
|
||||
is_non_accom_stem = False
|
||||
for stem in NON_ACCOM_STEMS:
|
||||
if stem == self.primary_stem:
|
||||
is_non_accom_stem = True
|
||||
|
||||
mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
|
||||
|
||||
if self.is_post_process:
|
||||
mask = spec_utils.merge_artifacts(mask)
|
||||
mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
|
||||
|
||||
y_spec = mask * X_mag * np.exp(1.j * X_phase)
|
||||
v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
|
||||
|
||||
return y_spec, v_spec
|
||||
|
||||
X_mag, X_phase = spec_utils.preprocess(X_spec)
|
||||
n_frame = X_mag.shape[2]
|
||||
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.crop_size, model.offset)
|
||||
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
|
||||
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||
X_mag_pad /= X_mag_pad.max()
|
||||
mask = _execute(X_mag_pad, roi_size)
|
||||
@@ -819,7 +813,7 @@ class SeperateVR(SeperateAttributes):
|
||||
else:
|
||||
mask = mask[:, :, :n_frame]
|
||||
|
||||
y_spec, v_spec = postprocess(mask, X_mag, X_phase, aggressiveness)
|
||||
y_spec, v_spec = postprocess(mask, X_mag, X_phase)
|
||||
|
||||
return y_spec, v_spec
|
||||
|
||||
@@ -889,6 +883,7 @@ def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=Fa
|
||||
margin = margin_set
|
||||
chunk_size = chunk_set*44100
|
||||
assert not margin == 0, 'margin cannot be zero!'
|
||||
|
||||
if margin > chunk_size:
|
||||
margin = chunk_size
|
||||
if chunk_set == 0 or samples < chunk_size:
|
||||
@@ -941,4 +936,4 @@ def save_format(audio_path, save_format, mp3_bit_set):
|
||||
try:
|
||||
os.remove(audio_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(e)
|
||||
Reference in New Issue
Block a user