Add files via upload

inference_MDX.py (429)
@@ -35,6 +35,7 @@ import pydub
import shutil
import soundfile as sf
import subprocess
from UVR import MainWindow
import sys
import time
import time # Timer
@@ -61,45 +62,50 @@ class Predictor():
self.noise_pro_select_set_var = tk.StringVar(value='MDX-NET_Noise_Profile_14_kHz')
self.compensate_v_var = tk.StringVar(value=1.03597672895)

top= Toplevel()
mdx_model_set = Toplevel()

top.geometry("740x550")
window_height = 740
window_width = 550
mdx_model_set.geometry("490x515")
window_height = 490
window_width = 515

top.title("Specify Parameters")
mdx_model_set.title("Specify Parameters")

top.resizable(False, False) # This code helps to disable windows from resizing
mdx_model_set.resizable(False, False) # This code helps to disable windows from resizing

top.attributes("-topmost", True)
mdx_model_set.attributes("-topmost", True)

screen_width = top.winfo_screenwidth()
screen_height = top.winfo_screenheight()
screen_width = mdx_model_set.winfo_screenwidth()
screen_height = mdx_model_set.winfo_screenheight()

x_cordinate = int((screen_width/2) - (window_width/2))
y_cordinate = int((screen_height/2) - (window_height/2))

top.geometry("{}x{}+{}+{}".format(window_width, window_height, x_cordinate, y_cordinate))
mdx_model_set.geometry("{}x{}+{}+{}".format(window_width, window_height, x_cordinate, y_cordinate))

x = main_window.winfo_x()
y = main_window.winfo_y()
mdx_model_set.geometry("+%d+%d" %(x+50,y+150))
mdx_model_set.wm_transient(main_window)

# change title bar icon
top.iconbitmap('img\\UVR-Icon-v2.ico')
mdx_model_set.iconbitmap('img\\UVR-Icon-v2.ico')

tabControl = ttk.Notebook(top)
mdx_model_set_window = ttk.Notebook(mdx_model_set)

tabControl.pack(expand = 1, fill ="both")
mdx_model_set_window.pack(expand = 1, fill ="both")

tabControl.grid_rowconfigure(0, weight=1)
tabControl.grid_columnconfigure(0, weight=1)
mdx_model_set_window.grid_rowconfigure(0, weight=1)
mdx_model_set_window.grid_columnconfigure(0, weight=1)

frame0=Frame(tabControl,highlightbackground='red',highlightthicknes=0)
frame0=Frame(mdx_model_set_window,highlightbackground='red',highlightthicknes=0)
frame0.grid(row=0,column=0,padx=0,pady=0)

frame0.tkraise(frame0)
#frame0.tkraise(frame0)

space_small = ' '*20
space_small_1 = ' '*10

l0=tk.Label(frame0, text=f'{space_small}Stem Type{space_small}', font=("Century Gothic", "9"), foreground='#13a4c9')
l0=tk.Label(frame0, text=f'\n{space_small}Stem Type{space_small}', font=("Century Gothic", "9"), foreground='#13a4c9')
l0.grid(row=3,column=0,padx=0,pady=5)

l0=ttk.OptionMenu(frame0, self.mdxnetModeltype_var, None, 'Vocals', 'Instrumental', 'Other', 'Bass', 'Drums')
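The hunk above replaces the screen-centered `top` dialog with `mdx_model_set`, a smaller (490x515) window that is offset from the main UVR window and marked transient so it stays attached to it. A minimal standalone sketch of that anchoring pattern (names here are illustrative, not from this commit):

import tkinter as tk

def open_parameters_dialog(main_window: tk.Tk) -> tk.Toplevel:
    # Sketch only: mirrors the geometry/transient handling used above.
    dialog = tk.Toplevel(main_window)
    dialog.geometry("490x515")
    dialog.resizable(False, False)
    dialog.attributes("-topmost", True)
    x, y = main_window.winfo_x(), main_window.winfo_y()
    dialog.geometry("+%d+%d" % (x + 50, y + 150))  # offset from the parent, not screen-centered
    dialog.wm_transient(main_window)               # keep the dialog tied to the main window
    return dialog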
@@ -160,18 +166,15 @@ class Predictor():
torch.cuda.empty_cache()
gui_progress_bar.set(0)
widget_button.configure(state=tk.NORMAL) # Enable Button
top.destroy()
self.okVar.set(1)
stop_button()
mdx_model_set.destroy()
return

l0=ttk.Button(frame0,text="Stop Process", command=stop)
l0.grid(row=13,column=1,padx=0,pady=30)

def change_event():
self.okVar.set(1)
#top.destroy()
pass

top.protocol("WM_DELETE_WINDOW", change_event)
mdx_model_set.protocol("WM_DELETE_WINDOW", stop)

frame0.wait_variable(self.okVar)
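The dialog blocks on `frame0.wait_variable(self.okVar)`, and both the Stop button and the window-close protocol now release that wait through `stop`. In isolation, the modal pattern looks like this (a sketch with hypothetical names):

import tkinter as tk

def run_modal(parent: tk.Tk) -> None:
    dialog = tk.Toplevel(parent)
    ok_var = tk.IntVar(value=0)

    def close_dialog():
        ok_var.set(1)       # releases wait_variable() below
        dialog.destroy()

    tk.Button(dialog, text="Stop Process", command=close_dialog).pack()
    dialog.protocol("WM_DELETE_WINDOW", close_dialog)
    dialog.wait_variable(ok_var)   # returns once ok_var has been written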
@@ -217,13 +220,13 @@ class Predictor():
stem_text_b = 'Vocals'
elif stemset_n == '(Other)':
stem_text_a = 'Other'
stem_text_b = 'the no \"Other\" track'
stem_text_b = 'mixture without selected stem'
elif stemset_n == '(Drums)':
stem_text_a = 'Drums'
stem_text_b = 'no \"Drums\" track'
stem_text_b = 'mixture without selected stem'
elif stemset_n == '(Bass)':
stem_text_a = 'Bass'
stem_text_b = 'No \"Bass\" track'
stem_text_b = 'mixture without selected stem'
else:
stem_text_a = 'Vocals'
stem_text_b = 'Instrumental'
@@ -263,7 +266,7 @@ class Predictor():
widget_text.write(base_text + 'Setting Demucs model to \"UVR_Demucs_Model_1\".\n\n')
demucs_model_set = 'UVR_Demucs_Model_1'

top.destroy()
mdx_model_set.destroy()

def prediction_setup(self):

@@ -287,6 +290,10 @@ class Predictor():
self.demucs.to(device)
self.demucs.load_state_dict(state)
widget_text.write('Done!\n')
if not data['segment'] == 'Default':
widget_text.write(base_text + 'Segments is only available in Demucs v3. Please use \"Chunks\" instead.\n')
else:
pass

if demucs_model_version == 'v2':
if '48' in demucs_model_set:
@@ -306,6 +313,10 @@ class Predictor():
self.demucs.to(device)
self.demucs.load_state_dict(torch.load("models/Demucs_Models/"f"{demucs_model_set}"))
widget_text.write('Done!\n')
if not data['segment'] == 'Default':
widget_text.write(base_text + 'Segments is only available in Demucs v3. Please use \"Chunks\" instead.\n')
else:
pass
self.demucs.eval()

if demucs_model_version == 'v3':
@@ -324,6 +335,37 @@ class Predictor():
widget_text.write('Done!\n')
if isinstance(self.demucs, BagOfModels):
widget_text.write(base_text + f"Selected Demucs model is a bag of {len(self.demucs.models)} model(s).\n")

if data['segment'] == 'Default':
segment = None
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
else:
try:
segment = int(data['segment'])
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
if split_mode:
widget_text.write(base_text + "Segments set to "f"{segment}.\n")
except:
segment = None
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment

self.onnx_models = {}
c = 0
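The segment handling added above repeats the same BagOfModels/single-model branch three times, and the single-model branch still assigns to `sub`, which only exists if a bag was iterated earlier. The intended override can be expressed once; this is a sketch under that reading, not the project's code:

def set_segment_length(model, segment):
    # Sketch: apply a segment override to a Demucs v3 model or a BagOfModels.
    if segment is None:
        return
    if isinstance(model, BagOfModels):        # BagOfModels from the bundled demucs code
        for sub in model.models:
            sub.segment = segment
    else:
        model.segment = segment               # assumes a single model exposes .segment directly

# Hypothetical call site equivalent to the block above:
# set_segment_length(self.demucs, None if data['segment'] == 'Default' else int(data['segment']))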
@@ -394,13 +436,13 @@ class Predictor():
if data['modelFolder']:
vocal_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}',)
vocal_path_mp3 = '{save_path}/{file_name}.mp3'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}',)
vocal_path_flac = '{save_path}/{file_name}.flac'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}',)
else:
vocal_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
@@ -428,13 +470,13 @@ class Predictor():
if data['modelFolder']:
Instrumental_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}',)
Instrumental_path_mp3 = '{save_path}/{file_name}.mp3'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}',)
Instrumental_path_flac = '{save_path}/{file_name}.flac'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}',)
else:
Instrumental_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
@@ -461,13 +503,13 @@ class Predictor():
if data['modelFolder']:
non_reduced_vocal_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}_No_Reduction',)
non_reduced_vocal_path_mp3 = '{save_path}/{file_name}.mp3'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}_No_Reduction',)
non_reduced_vocal_path_flac = '{save_path}/{file_name}.flac'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{vocal_name}_{mdx_model_name}_No_Reduction',)
else:
non_reduced_vocal_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
@@ -482,13 +524,13 @@ class Predictor():
if data['modelFolder']:
non_reduced_Instrumental_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}_No_Reduction',)
non_reduced_Instrumental_path_mp3 = '{save_path}/{file_name}.mp3'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}_No_Reduction',)
non_reduced_Instrumental_path_flac = '{save_path}/{file_name}.flac'.format(
save_path=save_path,
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{model_set_name}_No_Reduction',)
file_name = f'{os.path.basename(_basename)}_{Instrumental_name}_{mdx_model_name}_No_Reduction',)
else:
non_reduced_Instrumental_path = '{save_path}/{file_name}.wav'.format(
save_path=save_path,
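All four path blocks above make the same change: the output file name now embeds `mdx_model_name` instead of `model_set_name`, for the `.wav`, `.mp3`, and `.flac` variants alike. The shared naming scheme, written once as a sketch:

import os

def stem_output_path(save_path, basename, stem_name, model_name, ext, suffix=''):
    # e.g. "<save_path>/<track>_Vocals_UVR_MDXNET_Main_No_Reduction.wav"
    file_name = f'{os.path.basename(basename)}_{stem_name}_{model_name}{suffix}'
    return f'{save_path}/{file_name}.{ext}'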
@@ -918,19 +960,21 @@ class Predictor():
widget_text.write(base_text + 'Completed Separation!\n')

def demix(self, mix):
global chunk_set

# 1 = demucs only
# 0 = onnx only
if data['chunks'] == 'Full':
chunk_set = 0
else:
chunk_set = data['chunks']

if data['chunks'] == 'Auto':
widget_text.write(base_text + "Chunk size user-set to \"Full\"... \n")
elif data['chunks'] == 'Auto':
if data['gpu'] == 0:
try:
gpu_mem = round(torch.cuda.get_device_properties(0).total_memory/1.074e+9)
except:
widget_text.write(base_text + 'NVIDIA GPU Required for conversion!\n')
data['gpu'] = -1
pass
if int(gpu_mem) <= int(6):
chunk_set = int(5)
widget_text.write(base_text + 'Chunk size auto-set to 5... \n')
@@ -954,9 +998,9 @@ class Predictor():
if int(sys_mem) >= int(17):
chunk_set = int(60)
widget_text.write(base_text + 'Chunk size auto-set to 60... \n')
elif data['chunks'] == 'Full':
elif data['chunks'] == '0':
chunk_set = 0
widget_text.write(base_text + "Chunk size set to full... \n")
widget_text.write(base_text + "Chunk size user-set to \"Full\"... \n")
else:
chunk_set = int(data['chunks'])
widget_text.write(base_text + "Chunk size user-set to "f"{chunk_set}... \n")
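The 'Auto' branch above derives the chunk size from GPU VRAM (falling back when no NVIDIA GPU is detected) or from system RAM; only the 6 GB and 17 GB thresholds are visible in this excerpt. A partial sketch of that selection, with the middle tiers elided:

def auto_chunk_size(gpu_mem_gb=None, sys_mem_gb=None):
    # Partial sketch of the 'Auto' selection shown above; intermediate tiers are not in this hunk.
    if gpu_mem_gb is not None:            # NVIDIA GPU present
        if gpu_mem_gb <= 6:
            return 5
        # ... larger VRAM tiers omitted from this excerpt ...
    if sys_mem_gb is not None and sys_mem_gb >= 17:
        return 60
    return 0                              # 0 means 'Full' (no chunking) elsewhere in demix()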
@@ -986,29 +1030,33 @@ class Predictor():
segmented_mix[skip] = mix[:,start:end].copy()
if end == samples:
break


if not data['demucsmodel']:
sources = self.demix_base(segmented_mix, margin_size=margin)
elif data['demucs_only']:
if split_mode == True:
if no_chunk_demucs == False:
sources = self.demix_demucs_split(mix)
if split_mode == False:
if no_chunk_demucs == True:
sources = self.demix_demucs(segmented_mix, margin_size=margin)
else: # both, apply spec effects
base_out = self.demix_base(segmented_mix, margin_size=margin)
#print(split_mode)


if demucs_model_version == 'v1':
demucs_out = self.demix_demucs_v1(segmented_mix, margin_size=margin)
if no_chunk_demucs == False:
demucs_out = self.demix_demucs_v1_split(mix)
if no_chunk_demucs == True:
demucs_out = self.demix_demucs_v1(segmented_mix, margin_size=margin)
if demucs_model_version == 'v2':
demucs_out = self.demix_demucs_v2(segmented_mix, margin_size=margin)
if no_chunk_demucs == False:
demucs_out = self.demix_demucs_v2_split(mix)
if no_chunk_demucs == True:
demucs_out = self.demix_demucs_v2(segmented_mix, margin_size=margin)
if demucs_model_version == 'v3':
if split_mode == True:
if no_chunk_demucs == False:
demucs_out = self.demix_demucs_split(mix)
if split_mode == False:
if no_chunk_demucs == True:
demucs_out = self.demix_demucs(segmented_mix, margin_size=margin)

nan_count = np.count_nonzero(np.isnan(demucs_out)) + np.count_nonzero(np.isnan(base_out))
if nan_count > 0:
print('Warning: there are {} nan values in the array(s).'.format(nan_count))
@@ -1040,10 +1088,15 @@ class Predictor():
onnxitera = len(mixes)
onnxitera_calc = onnxitera * 2
gui_progress_bar_onnx = 0
widget_text.write(base_text + "Running ONNX Inference...\n")
widget_text.write(base_text + "Processing "f"{onnxitera} slices... ")
progress_bar = 0

print(' Running ONNX Inference...')

if onnxitera == 1:
widget_text.write(base_text + f"Running ONNX Inference... ")
else:
widget_text.write(base_text + f"Running ONNX Inference...{space}\n")

for mix in mixes:
gui_progress_bar_onnx += 1
if data['demucsmodel']:
@@ -1053,6 +1106,15 @@ class Predictor():
update_progress(**progress_kwargs,
step=(0.1 + (0.9/onnxitera * gui_progress_bar_onnx)))

progress_bar += 100
step = (progress_bar / onnxitera)

if onnxitera == 1:
pass
else:
percent_prog = f"{base_text}MDX-Net Inference Progress: {gui_progress_bar_onnx}/{onnxitera} | {round(step)}%"
widget_text.percentage(percent_prog)

cmix = mixes[mix]
sources = []
n_sample = cmix.shape[1]
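The progress lines added here drive the new textual percentage read-out: each slice adds 100 to a counter and the counter is divided by the slice count, i.e. plain percent-complete arithmetic. Reduced to a sketch:

def slice_percent(done_slices: int, total_slices: int) -> int:
    # progress_bar += 100 per slice; step = progress_bar / total  ->  percent complete
    return round(100 * done_slices / total_slices)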
@@ -1088,21 +1150,35 @@ class Predictor():
chunked_sources.append(sources)
_sources = np.concatenate(chunked_sources, axis=-1)
del self.onnx_models
widget_text.write('Done!\n')

if onnxitera == 1:
widget_text.write('Done!\n')
else:
widget_text.write('\n')

return _sources
def demix_demucs(self, mix, margin_size):
#print('shift_set ', shift_set)
processed = {}
demucsitera = len(mix)
demucsitera_calc = demucsitera * 2
gui_progress_bar_demucs = 0
widget_text.write(base_text + "Split Mode is off. (Chunks enabled for Demucs Model)\n")
widget_text.write(base_text + "Running Demucs Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")
progress_bar = 0
if demucsitera == 1:
widget_text.write(base_text + f"Running Demucs Inference... ")
else:
widget_text.write(base_text + f"Running Demucs Inference...{space}\n")

print(' Running Demucs Inference...')
for nmix in mix:
gui_progress_bar_demucs += 1
progress_bar += 100
step = (progress_bar / demucsitera)
if demucsitera == 1:
pass
else:
percent_prog = f"{base_text}Demucs Inference Progress: {gui_progress_bar_demucs}/{demucsitera} | {round(step)}%"
widget_text.percentage(percent_prog)
update_progress(**progress_kwargs,
step=(0.35 + (1.05/demucsitera_calc * gui_progress_bar_demucs)))
cmix = mix[nmix]
@@ -1110,8 +1186,17 @@ class Predictor():
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
with torch.no_grad():
#print(split_mode)
sources = apply_model(self.demucs, cmix[None], split=split_mode, device=device, overlap=overlap_set, shifts=shift_set, progress=False)[0]
sources = apply_model(self.demucs, cmix[None],
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
device=device,
overlap=overlap_set,
shifts=shift_set,
progress=False,
segmen=False,
**progress_demucs_kwargs)[0]
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
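The rewritten `apply_model` call threads the GUI objects (`gui_progress_bar`, `widget_text`, `update_prog`), a `segmen` flag, and `**progress_demucs_kwargs` into Demucs. Upstream `demucs.apply.apply_model` does not take these arguments, so this assumes the patched copy bundled with UVR; a hedged sketch of the signature implied by the call sites (defaults are illustrative):

def apply_model(model, mix, gui_progress_bar, widget_text, update_prog,
                split=True, device='cpu', overlap=0.25, shifts=0,
                progress=False, segmen=False,
                total_files=1, file_num=1, inference_type='inference_mdx'):
    # Assumed (patched) entry point: separates `mix` while reporting progress
    # back through the GUI objects above. Not the upstream demucs API.
    ...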
@@ -1123,17 +1208,21 @@ class Predictor():

sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')

if demucsitera == 1:
widget_text.write('Done!\n')
else:
widget_text.write('\n')
#print('the demucs model is done running')

return sources

def demix_demucs_split(self, mix):

#print('shift_set ', shift_set)
widget_text.write(base_text + "Split Mode is on. (Chunks disabled for Demucs Model)\n")
widget_text.write(base_text + "Running Demucs Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")

if split_mode:
widget_text.write(base_text + f"Running Demucs Inference...{space}\n")
else:
widget_text.write(base_text + f"Running Demucs Inference... ")
print(' Running Demucs Inference...')

mix = torch.tensor(mix, dtype=torch.float32)
@@ -1141,14 +1230,26 @@ class Predictor():
mix = (mix - ref.mean()) / ref.std()

with torch.no_grad():
sources = apply_model(self.demucs, mix[None], split=split_mode, device=device, overlap=overlap_set, shifts=shift_set, progress=False)[0]
sources = apply_model(self.demucs,
mix[None],
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
device=device,
overlap=overlap_set,
shifts=shift_set,
progress=False,
segmen=True,
**progress_demucs_kwargs)[0]

widget_text.write('Done!\n')
if split_mode:
widget_text.write('\n')
else:
widget_text.write('Done!\n')

sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]

#print('the demucs model is done running')

return sources
@@ -1157,11 +1258,21 @@ class Predictor():
demucsitera = len(mix)
demucsitera_calc = demucsitera * 2
gui_progress_bar_demucs = 0
widget_text.write(base_text + "Running Demucs v1 Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")
progress_bar = 0
print(' Running Demucs Inference...')
if demucsitera == 1:
widget_text.write(base_text + f"Running Demucs v1 Inference... ")
else:
widget_text.write(base_text + f"Running Demucs v1 Inference...{space}\n")
for nmix in mix:
gui_progress_bar_demucs += 1
progress_bar += 100
step = (progress_bar / demucsitera)
if demucsitera == 1:
pass
else:
percent_prog = f"{base_text}Demucs v1 Inference Progress: {gui_progress_bar_demucs}/{demucsitera} | {round(step)}%"
widget_text.percentage(percent_prog)
update_progress(**progress_kwargs,
step=(0.35 + (1.05/demucsitera_calc * gui_progress_bar_demucs)))
cmix = mix[nmix]
@@ -1169,7 +1280,15 @@ class Predictor():
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
with torch.no_grad():
sources = apply_model_v1(self.demucs, cmix.to(device), split=split_mode, shifts=shift_set)
sources = apply_model_v1(self.demucs,
cmix.to(device),
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
segmen=False,
shifts=shift_set,
**progress_demucs_kwargs)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
@@ -1181,7 +1300,44 @@ class Predictor():

sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')

if demucsitera == 1:
widget_text.write('Done!\n')
else:
widget_text.write('\n')

return sources

def demix_demucs_v1_split(self, mix):

print(' Running Demucs Inference...')
if split_mode:
widget_text.write(base_text + f"Running Demucs v1 Inference...{space}\n")
else:
widget_text.write(base_text + f"Running Demucs v1 Inference... ")

mix = torch.tensor(mix, dtype=torch.float32)
ref = mix.mean(0)
mix = (mix - ref.mean()) / ref.std()

with torch.no_grad():
sources = apply_model_v1(self.demucs,
mix.to(device),
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
segmen=True,
shifts=shift_set,
**progress_demucs_kwargs)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]

if split_mode:
widget_text.write('\n')
else:
widget_text.write('Done!\n')

return sources

def demix_demucs_v2(self, mix, margin_size):
@@ -1189,11 +1345,22 @@ class Predictor():
demucsitera = len(mix)
demucsitera_calc = demucsitera * 2
gui_progress_bar_demucs = 0
widget_text.write(base_text + "Running Demucs v2 Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")
print(' Running Demucs Inference...')
progress_bar = 0
if demucsitera == 1:
widget_text.write(base_text + f"Running Demucs v2 Inference... ")
else:
widget_text.write(base_text + f"Running Demucs v2 Inference...{space}\n")

for nmix in mix:
gui_progress_bar_demucs += 1
progress_bar += 100
step = (progress_bar / demucsitera)
if demucsitera == 1:
pass
else:
percent_prog = f"{base_text}Demucs v2 Inference Progress: {gui_progress_bar_demucs}/{demucsitera} | {round(step)}%"
widget_text.percentage(percent_prog)

update_progress(**progress_kwargs,
step=(0.35 + (1.05/demucsitera_calc * gui_progress_bar_demucs)))
cmix = mix[nmix]
@@ -1201,7 +1368,16 @@ class Predictor():
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
with torch.no_grad():
sources = apply_model_v2(self.demucs, cmix.to(device), split=split_mode, overlap=overlap_set, shifts=shift_set)
sources = apply_model_v2(self.demucs,
cmix.to(device),
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
segmen=False,
overlap=overlap_set,
shifts=shift_set,
**progress_demucs_kwargs)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
@@ -1213,9 +1389,46 @@ class Predictor():

sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')

if demucsitera == 1:
widget_text.write('Done!\n')
else:
widget_text.write('\n')

return sources

def demix_demucs_v2_split(self, mix):
print(' Running Demucs Inference...')

if split_mode:
widget_text.write(base_text + f"Running Demucs v2 Inference...{space}\n")
else:
widget_text.write(base_text + f"Running Demucs v2 Inference... ")

mix = torch.tensor(mix, dtype=torch.float32)
ref = mix.mean(0)
mix = (mix - ref.mean()) / ref.std()
with torch.no_grad():
sources = apply_model_v2(self.demucs,
mix.to(device),
gui_progress_bar,
widget_text,
update_prog,
split=split_mode,
segmen=True,
overlap=overlap_set,
shifts=shift_set,
**progress_demucs_kwargs)

sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]

if split_mode:
widget_text.write('\n')
else:
widget_text.write('Done!\n')

return sources


data = {
@@ -1240,6 +1453,7 @@ data = {
'modelFolder': False,
'mp3bit': '320k',
'n_fft_scale': 6144,
'no_chunk': False,
'noise_pro_select': 'Auto Select',
'noisereduc_s': 3,
'non_red': False,
@@ -1247,6 +1461,7 @@ data = {
'normalize': False,
'overlap': 0.5,
'saveFormat': 'Wav',
'segment': 'Default',
'shifts': 0,
'split_mode': False,
'voc_only': False,
@@ -1286,6 +1501,7 @@ def main(window: tk.Wm,
text_widget: tk.Text,
button_widget: tk.Button,
progress_var: tk.Variable,
stop_thread,
**kwargs: dict):

global widget_text
@@ -1299,8 +1515,10 @@ def main(window: tk.Wm,
global n_fft_scale_set
global dim_f_set
global progress_kwargs
global progress_demucs_kwargs
global base_text
global model_set_name
global mdx_model_name
global stemset_n
global stem_text_a
global stem_text_b
@@ -1325,17 +1543,20 @@ def main(window: tk.Wm,
global stime
global model_hash
global demucs_switch
global no_chunk_demucs
global inst_only
global voc_only
global space
global main_window
global stop_button


# Update default settings
default_chunks = data['chunks']
default_noisereduc_s = data['noisereduc_s']
stop_button = stop_thread

widget_text = text_widget
gui_progress_bar = progress_var
widget_button = button_widget
main_window = window


#Error Handling

@@ -1361,6 +1582,15 @@ def main(window: tk.Wm,

data.update(kwargs)

global update_prog

# Update default settings
update_prog = update_progress
default_chunks = data['chunks']
default_noisereduc_s = data['noisereduc_s']
no_chunk_demucs = data['no_chunk']
space = ' '*90
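The new globals wire the GUI progress plumbing into the Demucs helpers; in short:

# update_prog = update_progress      -> progress callback handed to the patched apply_model / apply_model_v1 / apply_model_v2 calls
# no_chunk_demucs = data['no_chunk'] -> selects chunked vs. whole-file Demucs processing in demix()
# space = ' '*90                     -> padding appended to the "Running ... Inference" log lines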
if data['DemucsModel_MDX'] == "Tasnet v1":
demucs_model_set_name = 'tasnet.th'
demucs_model_version = 'v1'
@@ -1436,6 +1666,10 @@ def main(window: tk.Wm,
mdx_model_name = 'UVR_MDXNET_KARA'
elif model_set_name == 'UVR-MDX-NET Main':
mdx_model_name = 'UVR_MDXNET_Main'
elif model_set_name == 'UVR-MDX-NET Inst 1':
mdx_model_name = 'UVR_MDXNET_Inst_1'
elif model_set_name == 'UVR-MDX-NET Inst 2':
mdx_model_name = 'UVR_MDXNET_Inst_2'
else:
mdx_model_name = data['mdxnetModel']
@@ -1583,12 +1817,18 @@ def main(window: tk.Wm,
_basename = f'{data["export_path"]}/{str(randomnum)}_{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'
else:
_basename = f'{data["export_path"]}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'


inference_type = 'inference_mdx'

# -Get text and update progress-
base_text = get_baseText(total_files=len(data['input_paths']),
file_num=file_num)
progress_kwargs = {'progress_var': progress_var,
'total_files': len(data['input_paths']),
'file_num': file_num}
progress_demucs_kwargs = {'total_files': len(data['input_paths']),
'file_num': file_num, 'inference_type': inference_type}


if 'UVR' in demucs_model_set:
@@ -1603,10 +1843,11 @@ def main(window: tk.Wm,

if stemset_n == '(Instrumental)':
if not 'UVR' in demucs_model_set:
widget_text.write(base_text + 'The selected Demucs model cannot be used with this model.\n')
widget_text.write(base_text + 'Only 2 stem Demucs models are compatible with this model.\n')
widget_text.write(base_text + 'Setting Demucs model to \"UVR_Demucs_Model_1\".\n\n')
demucs_model_set = 'UVR_Demucs_Model_1'
if data['demucsmodel']:
widget_text.write(base_text + 'The selected Demucs model cannot be used with this model.\n')
widget_text.write(base_text + 'Only 2 stem Demucs models are compatible with this model.\n')
widget_text.write(base_text + 'Setting Demucs model to \"UVR_Demucs_Model_1\".\n\n')
demucs_model_set = 'UVR_Demucs_Model_1'

try:
if float(data['noisereduc_s']) >= 11:
@@ -1904,7 +2145,7 @@ def main(window: tk.Wm,
text_widget.write(f'\nError Received:\n\n')
text_widget.write(f'Could not write audio file.\n')
text_widget.write(f'This could be due to low storage on target device or a system permissions issue.\n')
text_widget.write(f"\nFor raw error details, go to the Error Log tab in the Help Guide.\n")
text_widget.write(f"\nGo to the Settings Menu and click \"Open Error Log\" for raw error details.\n")
text_widget.write(f'\nIf the error persists, please contact the developers.\n\n')
text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')
try:
@@ -2013,7 +2254,7 @@ def main(window: tk.Wm,
text_widget.write("\n" + base_text + f'Separation failed for the following audio file:\n')
text_widget.write(base_text + f'"{os.path.basename(music_file)}"\n')
text_widget.write(f'\nError Received:\n')
text_widget.write("\nFor raw error details, go to the Error Log tab in the Help Guide.\n")
text_widget.write("\nGo to the Settings Menu and click \"Open Error Log\" for raw error details.\n")
text_widget.write("\n" + f'Please address the error and try again.' + "\n")
text_widget.write(f'If this error persists, please contact the developers with the error details.\n\n')
text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')