Add files via upload

This commit is contained in:
Anjok07
2022-06-03 04:08:37 -05:00
committed by GitHub
parent 3f62d7e877
commit aa2dd10834
5 changed files with 845 additions and 361 deletions

View File

@@ -71,8 +71,10 @@ class Predictor():
self.onnx_models = {}
c = 0
self.models = get_models('tdf_extra', load=False, device=cpu, stems=modeltype)
widget_text.write(base_text + 'Loading ONNX model... ')
self.models = get_models('tdf_extra', load=False, device=cpu, stems=modeltype, n_fft_scale=n_fft_scale_set, dim_f=dim_f_set)
if not data['demucs_only']:
widget_text.write(base_text + 'Loading ONNX model... ')
update_progress(**progress_kwargs,
step=0.1)
c+=1
@@ -92,8 +94,10 @@ class Predictor():
print('model_set: ', model_set)
self.onnx_models[c] = ort.InferenceSession(os.path.join('models/MDX_Net_Models', model_set), providers=run_type)
widget_text.write('Done!\n')
if not data['demucs_only']:
widget_text.write('Done!\n')
def prediction(self, m):
#mix, rate = sf.read(m)
mix, rate = librosa.load(m, mono=False, sr=44100)
@@ -181,7 +185,7 @@ class Predictor():
step=(0.9))
widget_text.write('Done!\n')
widget_text.write(base_text + 'Performing Noise Reduction... ')
reduction_sen = float(int(data['noisereduc_s'])/10)
reduction_sen = float(data['noisereduc_s'])/10
subprocess.call("lib_v5\\sox\\sox.exe" + ' "' +
f"{str(non_reduced_vocal_path)}" + '" "' + f"{str(vocal_path)}" + '" ' +
"noisered lib_v5\\sox\\mdxnetnoisereduc.prof " + f"{reduction_sen}",
@@ -353,7 +357,8 @@ class Predictor():
if not data['demucsmodel']:
sources = self.demix_base(segmented_mix, margin_size=margin)
elif data['demucs_only']:
sources = self.demix_demucs(segmented_mix, margin_size=margin)
else: # both, apply spec effects
base_out = self.demix_base(segmented_mix, margin_size=margin)
demucs_out = self.demix_demucs(segmented_mix, margin_size=margin)
@@ -364,8 +369,8 @@ class Predictor():
sources = {}
sources[3] = (spec_effects(wave=[demucs_out[3],base_out[0]],
algorithm='default',
value=b[3])*1.03597672895) # compensation
algorithm=data['mixing'],
value=b[3])*float(data['compensate'])) # compensation
return sources
def demix_base(self, mixes, margin_size):
@@ -439,7 +444,6 @@ class Predictor():
cmix = torch.tensor(cmix, dtype=torch.float32)
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
shift_set = 0
with torch.no_grad():
sources = apply_model(self.demucs, cmix.to(device), split=True, overlap=overlap_set, shifts=shift_set)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
@@ -562,11 +566,20 @@ data = {
'chunks': 'auto',
'non_red': False,
'noisereduc_s': 3,
'mixing': 'default',
'ensChoose': 'Basic Ensemble',
'algo': 'Instrumentals (Min Spec)',
#Advanced Options
'appendensem': False,
'overlap': 0.5,
'shifts': 0,
'margin': 44100,
'channel': 64,
'compensate': 1.03597672895,
'demucs_only': False,
'mixing': 'Default',
'DemucsModel': 'demucs_extra-3646af93_org.th',
# Models
'instrumentalModel': None,
'useModel': None,
@@ -602,9 +615,6 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
global widget_text
global gui_progress_bar
global music_file
global channel_set
global margin_set
global overlap_set
global default_chunks
global default_noisereduc_s
global base_name
@@ -615,13 +625,17 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
global model_set_name
global ModelName_2
global channel_set
global margin_set
global overlap_set
global shift_set
global n_fft_scale_set
global dim_f_set
# Update default settings
default_chunks = data['chunks']
default_noisereduc_s = data['noisereduc_s']
channel_set = int(64)
margin_set = int(44100)
overlap_set = float(0.5)
widget_text = text_widget
gui_progress_bar = progress_var
@@ -647,6 +661,15 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
except:
pass
overlap_set = float(data['overlap'])
channel_set = int(data['channel'])
margin_set = int(data['margin'])
shift_set = int(data['shifts'])
n_fft_scale_set=6144
dim_f_set=2048
global nn_arch_sizes
global nn_architecture
@@ -1195,11 +1218,11 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
#MDX-Net Model
if data['mdx_ensem'] == 'UVR-MDX-NET 1':
mdx_ensem = 'UVR_MDXNET_9703'
mdx_ensem = 'UVR_MDXNET_1_9703'
if data['mdx_ensem'] == 'UVR-MDX-NET 2':
mdx_ensem = 'UVR_MDXNET_9682'
mdx_ensem = 'UVR_MDXNET_2_9682'
if data['mdx_ensem'] == 'UVR-MDX-NET 3':
mdx_ensem = 'UVR_MDXNET_9662'
mdx_ensem = 'UVR_MDXNET_3_9662'
if data['mdx_ensem'] == 'UVR-MDX-NET Karaoke':
mdx_ensem = 'UVR_MDXNET_KARA'
@@ -1207,11 +1230,11 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
#MDX-Net Model 2
if data['mdx_ensem_b'] == 'UVR-MDX-NET 1':
mdx_ensem_b = 'UVR_MDXNET_9703'
mdx_ensem_b = 'UVR_MDXNET_1_9703'
if data['mdx_ensem_b'] == 'UVR-MDX-NET 2':
mdx_ensem_b = 'UVR_MDXNET_9682'
mdx_ensem_b = 'UVR_MDXNET_2_9682'
if data['mdx_ensem_b'] == 'UVR-MDX-NET 3':
mdx_ensem_b = 'UVR_MDXNET_9662'
mdx_ensem_b = 'UVR_MDXNET_3_9662'
if data['mdx_ensem_b'] == 'UVR-MDX-NET Karaoke':
mdx_ensem_b = 'UVR_MDXNET_Karaoke'
if data['mdx_ensem_b'] == 'No Model':
@@ -1925,23 +1948,23 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
else:
text_widget.write('Ensemble Mode - Running Model - ' + mdx_name + '\n\n')
if mdx_name == 'UVR_MDXNET_9703':
mdx_ensem_b = 'UVR_MDXNET_9703'
model_set = 'UVR_MDXNET_9703.onnx'
model_set_name = 'UVR_MDXNET_9703'
modeltype = 'vocals-one'
if mdx_name == 'UVR_MDXNET_9682':
model_set = 'UVR_MDXNET_9682.onnx'
model_set_name = 'UVR_MDXNET_9682'
modeltype = 'vocals-one'
if mdx_name == 'UVR_MDXNET_9662':
model_set = 'UVR_MDXNET_9662.onnx'
model_set_name = 'UVR_MDXNET_9662'
modeltype = 'vocals-one'
if mdx_name == 'UVR_MDXNET_1_9703':
mdx_ensem_b = 'UVR_MDXNET_1_9703'
model_set = 'UVR_MDXNET_1_9703.onnx'
model_set_name = 'UVR_MDXNET_1_9703'
modeltype = 'v'
if mdx_name == 'UVR_MDXNET_2_9682':
model_set = 'UVR_MDXNET_2_9682.onnx'
model_set_name = 'UVR_MDXNET_2_9682'
modeltype = 'v'
if mdx_name == 'UVR_MDXNET_3_9662':
model_set = 'UVR_MDXNET_3_9662.onnx'
model_set_name = 'UVR_MDXNET_3_9662'
modeltype = 'v'
if mdx_name == 'UVR_MDXNET_Karaoke':
model_set = 'UVR_MDXNET_KARA.onnx'
model_set_name = 'UVR_MDXNET_Karaoke'
modeltype = 'vocals-one'
modeltype = 'v'
update_progress(**progress_kwargs,
step=0)
@@ -1958,7 +1981,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
e = os.path.join(data["export_path"])
demucsmodel = 'models/Demucs_Model/demucs_extra-3646af93_org.th'
demucsmodel = 'models/Demucs_Model/' + str(data['DemucsModel'])
pred = Predictor()
pred.prediction_setup(demucs_name=demucsmodel,