diff --git a/UVR.py b/UVR.py index 5c6c6ff..528eadc 100644 --- a/UVR.py +++ b/UVR.py @@ -53,11 +53,14 @@ from tkinter import * from tkinter.tix import * import re from typing import List +import sys import ssl logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) logging.info('UVR BEGIN') +PREVIOUS_PATCH_WIN = 'UVR_Patch_1_12_23_14_54' + is_dnd_compatible = True banner_placement = -2 @@ -159,12 +162,15 @@ VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models') MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models') DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models') DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo') +MDX_MIXER_PATH = os.path.join(BASE_PATH, 'lib_v5', 'mixer.ckpt') #Cache & Parameters VR_HASH_DIR = os.path.join(VR_MODELS_DIR, 'model_data') VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json') MDX_HASH_DIR = os.path.join(MDX_MODELS_DIR, 'model_data') MDX_HASH_JSON = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_data.json') +DEMUCS_MODEL_NAME_SELECT = os.path.join(DEMUCS_MODELS_DIR, 'model_data', 'model_name_mapper.json') +MDX_MODEL_NAME_SELECT = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_name_mapper.json') ENSEMBLE_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_ensembles') SETTINGS_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_settings') VR_PARAM_DIR = os.path.join(BASE_PATH, 'lib_v5', 'vr_network', 'modelparams') @@ -244,13 +250,17 @@ class ModelData(): self.is_primary_stem_only = root.is_primary_stem_only_var.get() self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() self.is_denoise = root.is_denoise_var.get() + self.mdx_batch_size = 1 if root.mdx_batch_size_var.get() == DEF_OPT else int(root.mdx_batch_size_var.get()) + self.is_mdx_ckpt = False self.wav_type_set = root.wav_type_set self.mp3_bit_set = root.mp3_bit_set_var.get() self.save_format = root.save_format_var.get() self.is_invert_spec = root.is_invert_spec_var.get() + self.is_mixer_mode = root.is_mixer_mode_var.get() self.demucs_stems = root.demucs_stems_var.get() self.demucs_source_list = [] self.demucs_stem_count = 0 + self.mixer_path = MDX_MIXER_PATH self.model_name = model_name self.process_method = selected_process_method self.model_status = False if self.model_name == CHOOSE_MODEL or self.model_name == NO_MODEL else True @@ -271,6 +281,8 @@ class ModelData(): self.is_pre_proc_model = is_pre_proc_model self.is_dry_check = is_dry_check self.model_samplerate = 44100 + self.model_capacity = 32, 128 + self.is_vr_51_model = False self.is_demucs_pre_proc_model_inst_mix = False self.manual_download_Button = None self.secondary_model_4_stem = [] @@ -301,27 +313,31 @@ class ModelData(): self.is_tta = root.is_tta_var.get() self.is_post_process = root.is_post_process_var.get() self.window_size = int(root.window_size_var.get()) - self.batch_size = int(root.batch_size_var.get()) + self.batch_size = 1 if root.batch_size_var.get() == DEF_OPT else int(root.batch_size_var.get()) self.crop_size = int(root.crop_size_var.get()) self.is_high_end_process = 'mirroring' if root.is_high_end_process_var.get() else 'None' self.post_process_threshold = float(root.post_process_threshold_var.get()) + self.model_capacity = 32, 128 self.model_path = os.path.join(VR_MODELS_DIR, f"{self.model_name}.pth") self.get_model_hash() if self.model_hash: - self.model_data = self.get_model_data(VR_HASH_DIR, root.vr_hash_MAPPER) + self.model_data = self.get_model_data(VR_HASH_DIR, root.vr_hash_MAPPER) if not self.model_hash == 
WOOD_INST_MODEL_HASH else WOOD_INST_PARAMS if self.model_data: vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"])) self.primary_stem = self.model_data["primary_stem"] self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] self.vr_model_param = ModelParameters(vr_model_param) self.model_samplerate = self.vr_model_param.param['sr'] + if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys(): + self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"] + self.is_vr_51_model = True else: self.model_status = False if self.process_method == MDX_ARCH_TYPE: self.is_secondary_model_activated = root.mdx_is_secondary_model_activate_var.get() if not is_secondary_model else False self.margin = int(root.margin_var.get()) - self.chunks = root.determine_auto_chunks(root.chunks_var.get(), self.is_gpu_conversion) + self.chunks = root.determine_auto_chunks(root.chunks_var.get(), self.is_gpu_conversion) if root.is_chunk_mdxnet_var.get() else 0 self.get_mdx_model_path() self.get_model_hash() if self.model_hash: @@ -392,12 +408,19 @@ class ModelData(): def get_mdx_model_path(self): - for file_name, chosen_mdx_model in MDX_NAME_SELECT.items(): + if self.model_name.endswith(CKPT): + # self.chunks = 0 + # self.is_mdx_batch_mode = True + self.is_mdx_ckpt = True + + ext = '' if self.is_mdx_ckpt else ONNX + + for file_name, chosen_mdx_model in root.mdx_name_select_MAPPER.items(): if self.model_name in chosen_mdx_model: - self.model_path = os.path.join(MDX_MODELS_DIR, f"{file_name}.onnx") + self.model_path = os.path.join(MDX_MODELS_DIR, f"{file_name}{ext}") break else: - self.model_path = os.path.join(MDX_MODELS_DIR, f"{self.model_name}.onnx") + self.model_path = os.path.join(MDX_MODELS_DIR, f"{self.model_name}{ext}") self.mixer_path = os.path.join(MDX_MODELS_DIR, f"mixer_val.ckpt") @@ -406,7 +429,7 @@ class ModelData(): demucs_newer = [True for x in DEMUCS_NEWER_TAGS if x in self.model_name] demucs_model_dir = DEMUCS_NEWER_REPO_DIR if demucs_newer else DEMUCS_MODELS_DIR - for file_name, chosen_model in DEMUCS_NAME_SELECT.items(): + for file_name, chosen_model in root.demucs_name_select_MAPPER.items(): if self.model_name in chosen_model: self.model_path = os.path.join(demucs_model_dir, file_name) break @@ -475,10 +498,13 @@ class ModelData(): break if not self.model_hash: - with open(self.model_path, 'rb') as f: - f.seek(- 10000 * 1024, 2) - self.model_hash = hashlib.md5(f.read()).hexdigest() - + try: + with open(self.model_path, 'rb') as f: + f.seek(- 10000 * 1024, 2) + self.model_hash = hashlib.md5(f.read()).hexdigest() + except: + self.model_hash = hashlib.md5(open(self.model_path,'rb').read()).hexdigest() + table_entry = {self.model_path: self.model_hash} model_hash_table.update(table_entry) @@ -523,6 +549,7 @@ class Ensembler(): stem_outputs = self.get_files_to_ensemble(folder=export_path, prefix=audio_file_base, suffix=f"_({stem_tag}).wav") audio_file_output = f"{self.is_testing_audio}{audio_file_base}{self.chosen_ensemble}_({stem_tag})" stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}.wav'.format(audio_file_output)) + if stem_outputs: spec_utils.ensemble_inputs(stem_outputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) save_format(stem_save_path, self.save_format, self.mp3_bit_set) @@ -537,11 +564,29 @@ class Ensembler(): except Exception as e: print(e) - def ensemble_manual(self, audio_inputs, audio_file_base): + def ensemble_manual(self, audio_inputs, audio_file_base, 
is_bulk=False): """Processes the given outputs and ensembles them with the chosen algorithm""" + is_mv_sep = True + + if is_bulk: + number_list = list(set([os.path.basename(i).split("_")[0] for i in audio_inputs])) + for n in number_list: + current_list = [i for i in audio_inputs if os.path.basename(i).startswith(n)] + audio_file_base = os.path.basename(current_list[0]).split('.wav')[0] + stem_testing = "instrum" if "Instrumental" in audio_file_base else "vocals" + if is_mv_sep: + audio_file_base = audio_file_base.split("_") + audio_file_base = f"{audio_file_base[1]}_{audio_file_base[2]}_{stem_testing}" + self.ensemble_manual_process(current_list, audio_file_base, is_bulk) + else: + self.ensemble_manual_process(audio_inputs, audio_file_base, is_bulk) + + def ensemble_manual_process(self, audio_inputs, audio_file_base, is_bulk): + algorithm = root.choose_algorithm_var.get() - stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}{}_({}).wav'.format(self.is_testing_audio, audio_file_base, algorithm)) + algorithm_text = "" if is_bulk else f"_({root.choose_algorithm_var.get()})" + stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}{}{}.wav'.format(self.is_testing_audio, audio_file_base, algorithm_text)) spec_utils.ensemble_inputs(audio_inputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) save_format(stem_save_path, self.save_format, self.mp3_bit_set) @@ -597,7 +642,7 @@ class ToolTip(object): tw.wm_overrideredirect(1) tw.wm_geometry("+%d+%d" % (x, y)) label = Label(tw, text=self.text, justify=LEFT, - background="#151515", foreground="#dedede", highlightcolor="#898b8e", relief=SOLID, borderwidth=1, + background="#333333", foreground="#ffffff", highlightcolor="#898b8e", relief=SOLID, borderwidth=1, font=(MAIN_FONT_NAME, f"{FONT_SIZE_1}", "normal"))#('Century Gothic', FONT_SIZE_4) label.pack(ipadx=1) @@ -797,6 +842,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.clear_cache_torch = False self.vr_hash_MAPPER = load_model_hash_data(VR_HASH_JSON) self.mdx_hash_MAPPER = load_model_hash_data(MDX_HASH_JSON) + self.mdx_name_select_MAPPER = load_model_hash_data(MDX_MODEL_NAME_SELECT) + self.demucs_name_select_MAPPER = load_model_hash_data(DEMUCS_MODEL_NAME_SELECT) self.is_gpu_available = torch.cuda.is_available() if not OPERATING_SYSTEM == 'Darwin' else torch.backends.mps.is_available() self.is_process_stopped = False self.inputs_from_dir = [] @@ -968,7 +1015,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): if arch_type == MDX_ARCH_TYPE: model_data: List[ModelData] = [ModelData(model, MDX_ARCH_TYPE)] if arch_type == DEMUCS_ARCH_TYPE: - model_data: List[ModelData] = [ModelData(model, DEMUCS_ARCH_TYPE)] + model_data: List[ModelData] = [ModelData(model, DEMUCS_ARCH_TYPE)]# return model_data @@ -979,7 +1026,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): if network == MDX_ARCH_TYPE: dir = MDX_HASH_DIR - [os.remove(os.path.join(dir, x)) for x in os.listdir(dir) if x not in 'model_data.json'] + [os.remove(os.path.join(dir, x)) for x in os.listdir(dir) if x not in ['model_data.json', 'model_name_mapper.json']] self.vr_model_var.set(CHOOSE_MODEL) self.mdx_net_model_var.set(CHOOSE_MODEL) self.model_data_table.clear() @@ -1113,21 +1160,21 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.mdx_net_model_Option_place = lambda:self.mdx_net_model_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, 
relheight=1/self.COL1_ROWS) self.help_hints(self.mdx_net_model_Label, text=CHOOSE_MODEL_HELP) - # MDX-chunks - self.chunks_Label = self.main_window_LABEL_SET(self.options_Frame, CHUNKS_MDX_MAIN_LABEL) - self.chunks_Label_place = lambda:self.chunks_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) - self.chunks_Option = ttk.Combobox(self.options_Frame, value=CHUNKS, textvariable=self.chunks_var) - self.chunks_Option_place = lambda:self.chunks_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) - self.combobox_entry_validation(self.chunks_Option, self.chunks_var, REG_CHUNKS, CHUNKS) - self.help_hints(self.chunks_Label, text=CHUNKS_HELP) - - # MDX-Margin - self.margin_Label = self.main_window_LABEL_SET(self.options_Frame, MARGIN_MDX_MAIN_LABEL) - self.margin_Label_place = lambda:self.margin_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) - self.margin_Option = ttk.Combobox(self.options_Frame, value=MARGIN_SIZE, textvariable=self.margin_var) - self.margin_Option_place = lambda:self.margin_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) - self.combobox_entry_validation(self.margin_Option, self.margin_var, REG_MARGIN, MARGIN_SIZE) - self.help_hints(self.margin_Label, text=MARGIN_HELP) + # MDX-Batches + self.mdx_batch_size_Label = self.main_window_LABEL_SET(self.options_Frame, BATCHES_MDX_MAIN_LABEL) + self.mdx_batch_size_Label_place = lambda:self.mdx_batch_size_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.mdx_batch_size_Option = ttk.Combobox(self.options_Frame, value=BATCH_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.mdx_batch_size_var) + self.mdx_batch_size_Option_place = lambda:self.mdx_batch_size_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.mdx_batch_size_Option, self.mdx_batch_size_var, REG_BATCHES, BATCH_SIZE) + self.help_hints(self.mdx_batch_size_Label, text=BATCH_SIZE_HELP) + + # MDX-Volume Compensation + self.compensate_Label = self.main_window_LABEL_SET(self.options_Frame, VOL_COMP_MDX_MAIN_LABEL) + self.compensate_Label_place = lambda:self.compensate_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.compensate_Option = ttk.Combobox(self.options_Frame, value=VOL_COMPENSATION, width=MENU_COMBOBOX_WIDTH, textvariable=self.compensate_var) + self.compensate_Option_place = lambda:self.compensate_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.compensate_Option, self.compensate_var, REG_COMPENSATION, VOL_COMPENSATION) + self.help_hints(self.compensate_Label, text=COMPENSATE_HELP) ### VR ARCH ### @@ -1301,10 +1348,10 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.segment_Option, 
self.mdx_net_model_Label, self.mdx_net_model_Option, - self.chunks_Label, - self.chunks_Option, - self.margin_Label, - self.margin_Option, + self.mdx_batch_size_Label, + self.mdx_batch_size_Option, + self.compensate_Label, + self.compensate_Option, self.chosen_ensemble_Label, self.chosen_ensemble_Option, self.save_current_settings_Label, @@ -1337,6 +1384,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.demucs_model_var, self.demucs_stems_var, self.is_chunk_demucs_var, + self.is_chunk_mdxnet_var, self.is_primary_stem_only_Demucs_var, self.is_secondary_stem_only_Demucs_var, self.is_primary_stem_only_var, @@ -1361,12 +1409,24 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): invalid = lambda:(var.set(default[0])) combobox.config(validate='focus', validatecommand=(self.register(validation), '%P'), invalidcommand=(self.register(invalid),)) + def combo_box_selection_clear(self): + for option in self.options_Frame.winfo_children(): + if type(option) is ttk.Combobox: + option.selection_clear() + def bind_widgets(self): """Bind widgets to the drag & drop mechanic""" - self.chosen_audio_tool_align = tk.BooleanVar(value=True) - add_align = lambda e:(self.chosen_audio_tool_Option['menu'].add_radiobutton(label=ALIGN_INPUTS, command=tk._setit(self.chosen_audio_tool_var, ALIGN_INPUTS)), self.chosen_audio_tool_align.set(False)) if self.chosen_audio_tool_align else None - + self.chosen_audio_tool_align = tk.BooleanVar(value=True) + other_items = [self.options_Frame, self.filePaths_Frame, self.title_Label, self.progressbar, self.conversion_Button, self.settings_Button, self.stop_Button, self.command_Text] + all_widgets = self.options_Frame.winfo_children() + self.filePaths_Frame.winfo_children() + other_items + + for option in all_widgets: + if type(option) is ttk.Combobox: + option.bind("", lambda e:option.selection_clear()) + else: + option.bind('', lambda e:(option.focus(), self.combo_box_selection_clear())) + if is_dnd_compatible: self.filePaths_saveTo_Button.drop_target_register(DND_FILES) self.filePaths_saveTo_Entry.drop_target_register(DND_FILES) @@ -1376,10 +1436,10 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.filePaths_saveTo_Entry.dnd_bind('<>', lambda e: drop(e, accept_mode='folder')) self.ensemble_listbox_Option.bind('<>', lambda e: self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION)) - self.options_Frame.bind(right_click_button, lambda e:self.right_click_menu_popup(e, main_menu=True)) - self.filePaths_musicFile_Entry.bind(right_click_button, lambda e:self.input_right_click_menu(e)) - self.filePaths_musicFile_Entry.bind('', lambda e:self.check_is_open_menu_view_inputs()) - + self.options_Frame.bind(right_click_button, lambda e:(self.right_click_menu_popup(e, main_menu=True), self.options_Frame.focus())) + self.filePaths_musicFile_Entry.bind(right_click_button, lambda e:(self.input_right_click_menu(e), self.filePaths_musicFile_Entry.focus())) + self.filePaths_musicFile_Entry.bind('', lambda e:(self.check_is_open_menu_view_inputs(), self.filePaths_musicFile_Entry.focus())) + #--Input/Export Methods-- def input_select_filedialog(self): @@ -1472,10 +1532,12 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): except Exception as e: self.error_log_var.set(error_text('Temp File Deletion', e)) - def get_files_from_dir(self, directory, ext): + def get_files_from_dir(self, directory, ext, is_mdxnet=False): """Gets files from specified directory that ends with specified extention""" - return 
tuple(os.path.splitext(x)[0] for x in os.listdir(directory) if x.endswith(ext)) + #ext = '.onnx' if is_mdxnet else ext + + return tuple(x if is_mdxnet and x.endswith(CKPT) else os.path.splitext(x)[0] for x in os.listdir(directory) if x.endswith(ext)) def determine_auto_chunks(self, chunks, gpu): """Determines appropriate chunk size based on user computer specs""" @@ -1483,6 +1545,10 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): if OPERATING_SYSTEM == 'Darwin': gpu = -1 + if chunks == BATCH_MODE: + chunks = 0 + #self.chunks_var.set(AUTO_SELECT) + if chunks == 'Full': chunk_set = 0 elif chunks == 'Auto': @@ -2355,27 +2421,16 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): aggression_setting_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Aggression Setting') aggression_setting_Label.grid(row=3,column=0,padx=0,pady=5) - aggression_setting_Option = ttk.Combobox(vr_opt_frame, value=VR_BATCH, width=MENU_COMBOBOX_WIDTH, textvariable=self.aggression_setting_var) + aggression_setting_Option = ttk.Combobox(vr_opt_frame, value=VR_AGGRESSION, width=MENU_COMBOBOX_WIDTH, textvariable=self.aggression_setting_var) aggression_setting_Option.grid(row=4,column=0,padx=0,pady=5) - self.combobox_entry_validation(aggression_setting_Option, self.aggression_setting_var, REG_WINDOW, VR_BATCH) + self.combobox_entry_validation(aggression_setting_Option, self.aggression_setting_var, REG_WINDOW, ['10']) self.help_hints(aggression_setting_Label, text=AGGRESSION_SETTING_HELP) - self.crop_size_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Crop Size') - self.crop_size_Label.grid(row=5,column=0,padx=0,pady=5) - self.crop_size_sub_Label = self.menu_sub_LABEL_SET(vr_opt_frame, '(Works with select models only)', font_size=FONT_SIZE_1) - self.crop_size_sub_Label.grid(row=6,column=0,padx=0,pady=0) - self.crop_size_Option = ttk.Combobox(vr_opt_frame, value=VR_CROP, width=MENU_COMBOBOX_WIDTH, textvariable=self.crop_size_var) - self.crop_size_Option.grid(row=7,column=0,padx=0,pady=5) - self.combobox_entry_validation(self.crop_size_Option, self.crop_size_var, REG_WINDOW, VR_CROP) - self.help_hints(self.crop_size_Label, text=CROP_SIZE_HELP) - self.batch_size_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Batch Size') self.batch_size_Label.grid(row=8,column=0,padx=0,pady=5) - self.batch_size_sub_Label = self.menu_sub_LABEL_SET(vr_opt_frame, '(Works with select models only)', font_size=FONT_SIZE_1) - self.batch_size_sub_Label.grid(row=9,column=0,padx=0,pady=0) - self.batch_size_Option = ttk.Combobox(vr_opt_frame, value=VR_BATCH, width=MENU_COMBOBOX_WIDTH, textvariable=self.batch_size_var) + self.batch_size_Option = ttk.Combobox(vr_opt_frame, value=BATCH_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.batch_size_var) self.batch_size_Option.grid(row=10,column=0,padx=0,pady=5) - self.combobox_entry_validation(self.batch_size_Option, self.batch_size_var, REG_WINDOW, VR_BATCH) + self.combobox_entry_validation(self.batch_size_Option, self.batch_size_var, REG_BATCHES, BATCH_SIZE) self.help_hints(self.batch_size_Label, text=BATCH_SIZE_HELP) self.post_process_threshold_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Post-process Threshold') @@ -2468,14 +2523,14 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.chunks_demucs_Label.grid(row=7,column=0,padx=0,pady=5) self.chunks_demucs_Option = ttk.Combobox(demucs_frame, value=CHUNKS, width=MENU_COMBOBOX_WIDTH, textvariable=self.chunks_demucs_var) self.chunks_demucs_Option.grid(row=8,column=0,padx=0,pady=5) - 
self.combobox_entry_validation(self.chunks_demucs_Option, self.chunks_demucs_var, REG_CHUNKS, CHUNKS) - self.help_hints(self.chunks_demucs_Label, text=CHUNKS_HELP) + self.combobox_entry_validation(self.chunks_demucs_Option, self.chunks_demucs_var, REG_CHUNKS_DEMUCS, CHUNKS) + self.help_hints(self.chunks_demucs_Label, text=CHUNKS_DEMUCS_HELP) self.margin_demucs_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Chunk Margin') self.margin_demucs_Label.grid(row=9,column=0,padx=0,pady=5) self.margin_demucs_Option = ttk.Combobox(demucs_frame, value=MARGIN_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.margin_demucs_var) self.margin_demucs_Option.grid(row=10,column=0,padx=0,pady=5) - self.combobox_entry_validation(self.margin_Option, self.margin_demucs_var, REG_MARGIN, MARGIN_SIZE) + self.combobox_entry_validation(self.margin_demucs_Option, self.margin_demucs_var, REG_MARGIN, MARGIN_SIZE) self.help_hints(self.margin_demucs_Label, text=MARGIN_HELP) self.is_chunk_demucs_Option = ttk.Checkbutton(demucs_frame, text='Enable Chunks', width=DEMUCS_CHECKBOXS_WIDTH, variable=self.is_chunk_demucs_var, command=chunks_toggle) @@ -2494,14 +2549,18 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): is_invert_spec_Option.grid(row=14,column=0,padx=0,pady=0) self.help_hints(is_invert_spec_Option, text=IS_INVERT_SPEC_HELP) + is_mixer_mode_Option = ttk.Checkbutton(demucs_frame, text='Mixer Mode', width=DEMUCS_CHECKBOXS_WIDTH, variable=self.is_mixer_mode_var) + is_mixer_mode_Option.grid(row=15,column=0,padx=0,pady=0) + self.help_hints(is_mixer_mode_Option, text=IS_MIXER_MODE_HELP) + self.open_demucs_model_dir_Button = ttk.Button(demucs_frame, text='Open Demucs Model Folder', command=lambda:OPEN_FILE_func(DEMUCS_MODELS_DIR)) - self.open_demucs_model_dir_Button.grid(row=15,column=0,padx=0,pady=5) + self.open_demucs_model_dir_Button.grid(row=16,column=0,padx=0,pady=5) self.demucs_return_Button = ttk.Button(demucs_frame, text=BACK_TO_MAIN_MENU, command=lambda:(self.menu_advanced_demucs_options_close_window(), self.check_is_menu_settings_open())) - self.demucs_return_Button.grid(row=16,column=0,padx=0,pady=5) + self.demucs_return_Button.grid(row=17,column=0,padx=0,pady=5) self.demucs_close_Button = ttk.Button(demucs_frame, text='Close Window', command=lambda:self.menu_advanced_demucs_options_close_window()) - self.demucs_close_Button.grid(row=17,column=0,padx=0,pady=5) + self.demucs_close_Button.grid(row=18,column=0,padx=0,pady=5) demucs_pre_proc_model_title_Label = self.menu_title_LABEL_SET(demucs_pre_model_frame, "Pre-process Model") demucs_pre_proc_model_title_Label.grid(row=0,column=0,padx=0,pady=15) @@ -2536,6 +2595,10 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): tab1 = self.menu_tab_control(mdx_net_opt, self.mdx_secondary_model_vars) + enable_chunks = lambda:(margin_Option.configure(state=tk.NORMAL), chunks_Option.configure(state=tk.NORMAL)) + disable_chunks = lambda:(margin_Option.configure(state=tk.DISABLED), chunks_Option.configure(state=tk.DISABLED)) + chunks_toggle = lambda:enable_chunks() if self.is_chunk_mdxnet_var.get() else disable_chunks() + mdx_net_frame = self.menu_FRAME_SET(tab1) mdx_net_frame.grid(row=0,column=0,padx=0,pady=0) @@ -2543,47 +2606,60 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): mdx_opt_title.grid(row=0,column=0,padx=0,pady=10) if not self.chosen_process_method_var.get() == MDX_ARCH_TYPE: - chunks_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunks') - chunks_Label.grid(row=1,column=0,padx=0,pady=5) - chunks_Option = 
ttk.Combobox(mdx_net_frame, value=CHUNKS, width=MENU_COMBOBOX_WIDTH, textvariable=self.chunks_var) - chunks_Option.grid(row=2,column=0,padx=0,pady=5) - self.combobox_entry_validation(chunks_Option, self.chunks_var, REG_CHUNKS, CHUNKS) - self.help_hints(chunks_Label, text=CHUNKS_HELP) + mdx_batch_size_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Batch Size') + mdx_batch_size_Label.grid(row=5,column=0,padx=0,pady=5) + mdx_batch_size_Option = ttk.Combobox(mdx_net_frame, value=BATCH_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.mdx_batch_size_var) + mdx_batch_size_Option.grid(row=6,column=0,padx=0,pady=5) + self.combobox_entry_validation(mdx_batch_size_Option, self.mdx_batch_size_var, REG_SHIFTS, BATCH_SIZE) + self.help_hints(mdx_batch_size_Label, text=BATCH_SIZE_HELP) + + compensate_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Volume Compensation') + compensate_Label.grid(row=7,column=0,padx=0,pady=5) + compensate_Option = ttk.Combobox(mdx_net_frame, value=VOL_COMPENSATION, width=MENU_COMBOBOX_WIDTH, textvariable=self.compensate_var) + compensate_Option.grid(row=8,column=0,padx=0,pady=5) + self.combobox_entry_validation(compensate_Option, self.compensate_var, REG_COMPENSATION, VOL_COMPENSATION) + self.help_hints(compensate_Label, text=COMPENSATE_HELP) + + chunks_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunks') + chunks_Label.grid(row=1,column=0,padx=0,pady=5) + chunks_Option = ttk.Combobox(mdx_net_frame, value=CHUNKS, width=MENU_COMBOBOX_WIDTH, textvariable=self.chunks_var) + chunks_Option.grid(row=2,column=0,padx=0,pady=5) + self.combobox_entry_validation(chunks_Option, self.chunks_var, REG_CHUNKS, CHUNKS) + self.help_hints(chunks_Label, text=CHUNKS_HELP) + + margin_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunk Margin') + margin_Label.grid(row=3,column=0,padx=0,pady=5) + margin_Option = ttk.Combobox(mdx_net_frame, value=MARGIN_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.margin_var) + margin_Option.grid(row=4,column=0,padx=0,pady=5) + self.combobox_entry_validation(margin_Option, self.margin_var, REG_MARGIN, MARGIN_SIZE) + self.help_hints(margin_Label, text=MARGIN_HELP) - margin_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunk Margin') - margin_Label.grid(row=3,column=0,padx=0,pady=5) - margin_Option = ttk.Combobox(mdx_net_frame, value=MARGIN_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.margin_var) - margin_Option.grid(row=4,column=0,padx=0,pady=5) - self.combobox_entry_validation(margin_Option, self.margin_var, REG_MARGIN, MARGIN_SIZE) - self.help_hints(margin_Label, text=MARGIN_HELP) - - compensate_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Volume Compensation') - compensate_Label.grid(row=5,column=0,padx=0,pady=5) - compensate_Option = ttk.Combobox(mdx_net_frame, value=VOL_COMPENSATION, width=MENU_COMBOBOX_WIDTH, textvariable=self.compensate_var) - compensate_Option.grid(row=6,column=0,padx=0,pady=5) - self.combobox_entry_validation(compensate_Option, self.compensate_var, REG_COMPENSATION, VOL_COMPENSATION) - self.help_hints(compensate_Label, text=COMPENSATE_HELP) + is_chunk_mdxnet_Option = ttk.Checkbutton(mdx_net_frame, text='Enable Chunks', width=MDX_CHECKBOXS_WIDTH, variable=self.is_chunk_mdxnet_var, command=chunks_toggle) + is_chunk_mdxnet_Option.grid(row=10,column=0,padx=0,pady=0) + self.help_hints(is_chunk_mdxnet_Option, text=IS_CHUNK_MDX_NET_HELP) is_denoise_Option = ttk.Checkbutton(mdx_net_frame, text='Denoise Output', width=MDX_CHECKBOXS_WIDTH, variable=self.is_denoise_var) - is_denoise_Option.grid(row=8,column=0,padx=0,pady=0) + 
is_denoise_Option.grid(row=11,column=0,padx=0,pady=0) self.help_hints(is_denoise_Option, text=IS_DENOISE_HELP) is_invert_spec_Option = ttk.Checkbutton(mdx_net_frame, text='Spectral Inversion', width=MDX_CHECKBOXS_WIDTH, variable=self.is_invert_spec_var) - is_invert_spec_Option.grid(row=9,column=0,padx=0,pady=0) + is_invert_spec_Option.grid(row=12,column=0,padx=0,pady=0) self.help_hints(is_invert_spec_Option, text=IS_INVERT_SPEC_HELP) clear_mdx_cache_Button = ttk.Button(mdx_net_frame, text='Clear Auto-Set Cache', command=lambda:self.clear_cache(MDX_ARCH_TYPE)) - clear_mdx_cache_Button.grid(row=10,column=0,padx=0,pady=5) + clear_mdx_cache_Button.grid(row=13,column=0,padx=0,pady=5) self.help_hints(clear_mdx_cache_Button, text=CLEAR_CACHE_HELP) open_mdx_model_dir_Button = ttk.Button(mdx_net_frame, text='Open MDX-Net Models Folder', command=lambda:OPEN_FILE_func(MDX_MODELS_DIR)) - open_mdx_model_dir_Button.grid(row=11,column=0,padx=0,pady=5) + open_mdx_model_dir_Button.grid(row=14,column=0,padx=0,pady=5) mdx_return_Button = ttk.Button(mdx_net_frame, text=BACK_TO_MAIN_MENU, command=lambda:(self.menu_advanced_mdx_options_close_window(), self.check_is_menu_settings_open())) - mdx_return_Button.grid(row=12,column=0,padx=0,pady=5) + mdx_return_Button.grid(row=15,column=0,padx=0,pady=5) mdx_close_Button = ttk.Button(mdx_net_frame, text='Close Window', command=lambda:self.menu_advanced_mdx_options_close_window()) - mdx_close_Button.grid(row=13,column=0,padx=0,pady=5) + mdx_close_Button.grid(row=16,column=0,padx=0,pady=5) + + chunks_toggle() self.menu_placement(mdx_net_opt, "Advanced MDX-Net Options", is_help_hints=True, close_function=self.menu_advanced_mdx_options_close_window) @@ -3197,26 +3273,44 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): def pop_up_mdx_model(self, mdx_model_hash, model_path): """Opens MDX-Net model settings""" - is_onnx_model = True + is_compatible_model = True + is_ckpt = False + primary_stem = VOCAL_STEM try: - model = onnx.load(model_path) - model_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in model.graph.input][0] - dim_f = model_shapes[2] - dim_t = int(math.log(model_shapes[3], 2)) + if model_path.endswith(ONNX): + model = onnx.load(model_path) + model_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in model.graph.input][0] + dim_f = model_shapes[2] + dim_t = int(math.log(model_shapes[3], 2)) + n_fft = '6144' + + if model_path.endswith(CKPT): + is_ckpt = True + model_params = torch.load(model_path, map_location=lambda storage, loc: storage)['hyper_parameters'] + print('model_params: ', model_params) + dim_f = model_params['dim_f'] + dim_t = int(math.log(model_params['dim_t'], 2)) + n_fft = model_params['n_fft'] + + for stem in STEM_SET_MENU: + if model_params['target_name'] == stem.lower(): + primary_stem = INST_STEM if model_params['target_name'] == OTHER_STEM.lower() else stem + except Exception as e: dim_f = 0 dim_t = 0 self.error_dialoge(INVALID_ONNX_MODEL_ERROR) self.error_log_var.set("{}".format(error_text('MDX-Net Model Settings', e))) - is_onnx_model = False - - if is_onnx_model: + is_compatible_model = False + self.mdx_model_params = None + + if is_compatible_model: mdx_model_set = Toplevel(root) - mdx_n_fft_scale_set_var = tk.StringVar(value='6144') + mdx_n_fft_scale_set_var = tk.StringVar(value=n_fft) mdx_dim_f_set_var = tk.StringVar(value=dim_f) mdx_dim_t_set_var = tk.StringVar(value=dim_t) - primary_stem_var = tk.StringVar(value='Vocals') + primary_stem_var = 
tk.StringVar(value=primary_stem) mdx_compensate_var = tk.StringVar(value=1.035) mdx_model_set_Frame = self.menu_FRAME_SET(mdx_model_set) @@ -3270,6 +3364,11 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): stop_process_Button = ttk.Button(mdx_model_set_Frame, text="Cancel", command=lambda:cancel()) stop_process_Button.grid(row=17,column=0,padx=0,pady=0) + if is_ckpt: + mdx_dim_t_set_Option.configure(state=DISABLED) + mdx_dim_f_set_Option.configure(state=DISABLED) + mdx_n_fft_scale_set_Option.configure(state=DISABLED) + def pull_data(): mdx_model_params = { 'compensate': float(mdx_compensate_var.get()), @@ -3304,7 +3403,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): vr_param_menu = Toplevel() get_vr_params = lambda dir, ext:tuple(os.path.splitext(x)[0] for x in os.listdir(dir) if x.endswith(ext)) - new_vr_params = get_vr_params(VR_PARAM_DIR, '.json') + new_vr_params = get_vr_params(VR_PARAM_DIR, JSON) vr_model_param_var = tk.StringVar(value='None Selected') vr_model_stem_var = tk.StringVar(value='Vocals') @@ -3474,7 +3573,6 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): if user_refresh: self.download_list_state() - self.download_list_fill() for widget in self.download_center_Buttons:widget.configure(state=tk.NORMAL) if refresh_list_Button: @@ -3491,8 +3589,15 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.app_update_status_Text_var.set('UVR Version Current') else: is_new_update = True - self.app_update_status_Text_var.set(f"Update Found: {self.lastest_version}") - self.app_update_button_Text_var.set('Click Here to Update') + is_beta_version = True if self.lastest_version == PREVIOUS_PATCH_WIN and BETA_VERSION in current_patch else False + + if is_beta_version: + self.app_update_status_Text_var.set(f"Roll Back: {self.lastest_version}") + self.app_update_button_Text_var.set('Click Here to Roll Back') + else: + self.app_update_status_Text_var.set(f"Update Found: {self.lastest_version}") + self.app_update_button_Text_var.set('Click Here to Update') + if OPERATING_SYSTEM == "Windows": self.download_update_link_var.set('{}{}{}'.format(UPDATE_REPO, self.lastest_version, application_extension)) self.download_update_path_var.set(os.path.join(BASE_PATH, f'{self.lastest_version}{application_extension}')) @@ -3502,7 +3607,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.download_update_link_var.set(UPDATE_LINUX_REPO) if not user_refresh: - self.command_Text.write(f"\n\nNew Update Found: {self.lastest_version}\n\nClick the update button in the \"Settings\" menu to download and install!") + if not is_beta_version: + self.command_Text.write(f"\n\nNew Update Found: {self.lastest_version}\n\nClick the update button in the \"Settings\" menu to download and install!") self.download_model_settings() @@ -3602,6 +3708,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.vr_hash_MAPPER = json.load(urllib.request.urlopen(VR_MODEL_DATA_LINK)) self.mdx_hash_MAPPER = json.load(urllib.request.urlopen(MDX_MODEL_DATA_LINK)) + self.mdx_name_select_MAPPER = json.load(urllib.request.urlopen(MDX_MODEL_NAME_DATA_LINK)) + self.demucs_name_select_MAPPER = json.load(urllib.request.urlopen(DEMUCS_MODEL_NAME_DATA_LINK)) try: vr_hash_MAPPER_dump = json.dumps(self.vr_hash_MAPPER, indent=4) @@ -3611,7 +3719,20 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): mdx_hash_MAPPER_dump = json.dumps(self.mdx_hash_MAPPER, indent=4) with open(MDX_HASH_JSON, "w") as outfile: 
outfile.write(mdx_hash_MAPPER_dump) + + mdx_name_select_MAPPER_dump = json.dumps(self.mdx_name_select_MAPPER, indent=4) + with open(MDX_MODEL_NAME_SELECT, "w") as outfile: + outfile.write(mdx_name_select_MAPPER_dump) + + demucs_name_select_MAPPER_dump = json.dumps(self.demucs_name_select_MAPPER, indent=4) + with open(DEMUCS_MODEL_NAME_SELECT, "w") as outfile: + outfile.write(demucs_name_select_MAPPER_dump) + except Exception as e: + # self.vr_hash_MAPPER = load_model_hash_data(VR_HASH_JSON) + # self.mdx_hash_MAPPER = load_model_hash_data(MDX_HASH_JSON) + # self.mdx_name_select_MAPPER = load_model_hash_data(MDX_MODEL_NAME_SELECT) + # self.demucs_name_select_MAPPER = load_model_hash_data(DEMUCS_MODEL_NAME_SELECT) self.error_log_var.set(e) print(e) @@ -3813,11 +3934,11 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): def fix_names(file, name_mapper: dict):return tuple(new_name for (old_name, new_name) in name_mapper.items() if file in old_name) - new_vr_models = self.get_files_from_dir(VR_MODELS_DIR, '.pth') - new_mdx_models = self.get_files_from_dir(MDX_MODELS_DIR, '.onnx') - new_demucs_models = self.get_files_from_dir(DEMUCS_MODELS_DIR, ('.ckpt', '.gz', '.th')) + self.get_files_from_dir(DEMUCS_NEWER_REPO_DIR, '.yaml') - new_ensembles_found = self.get_files_from_dir(ENSEMBLE_CACHE_DIR, '.json') - new_settings_found = self.get_files_from_dir(SETTINGS_CACHE_DIR, '.json') + new_vr_models = self.get_files_from_dir(VR_MODELS_DIR, PTH) + new_mdx_models = self.get_files_from_dir(MDX_MODELS_DIR, (ONNX, CKPT), is_mdxnet=True) + new_demucs_models = self.get_files_from_dir(DEMUCS_MODELS_DIR, (CKPT, '.gz', '.th')) + self.get_files_from_dir(DEMUCS_NEWER_REPO_DIR, YAML) + new_ensembles_found = self.get_files_from_dir(ENSEMBLE_CACHE_DIR, JSON) + new_settings_found = self.get_files_from_dir(SETTINGS_CACHE_DIR, JSON) new_models_found = new_vr_models + new_mdx_models + new_demucs_models is_online = self.is_online_model_menu @@ -3850,8 +3971,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.model_data_table = [] vr_model_list = loop_directories(self.vr_model_Option, self.vr_model_var, new_vr_models, VR_ARCH_TYPE, name_mapper=None) - mdx_model_list = loop_directories(self.mdx_net_model_Option, self.mdx_net_model_var, new_mdx_models, MDX_ARCH_TYPE, name_mapper=MDX_NAME_SELECT) - demucs_model_list = loop_directories(self.demucs_model_Option, self.demucs_model_var, new_demucs_models, DEMUCS_ARCH_TYPE, name_mapper=DEMUCS_NAME_SELECT) + mdx_model_list = loop_directories(self.mdx_net_model_Option, self.mdx_net_model_var, new_mdx_models, MDX_ARCH_TYPE, name_mapper=self.mdx_name_select_MAPPER) + demucs_model_list = loop_directories(self.demucs_model_Option, self.demucs_model_var, new_demucs_models, DEMUCS_ARCH_TYPE, name_mapper=self.demucs_name_select_MAPPER) self.ensemble_model_list = vr_model_list + mdx_model_list + demucs_model_list self.last_found_models = new_models_found @@ -3903,10 +4024,10 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): if self.chosen_process_method_var.get() == MDX_ARCH_TYPE: self.mdx_net_model_Label_place() self.mdx_net_model_Option_place() - self.chunks_Label_place() - self.chunks_Option_place() - self.margin_Label_place() - self.margin_Option_place() + self.mdx_batch_size_Label_place() + self.mdx_batch_size_Option_place() + self.compensate_Label_place() + self.compensate_Option_place() general_shared_Buttons_place() stem_save_Options_place() no_ensemble_shared() @@ -4276,9 +4397,9 @@ class MainWindow(TkinterDnD.Tk if 
is_dnd_compatible else tk.Tk): self.active_processing_thread.start() def process_button_init(self): - self.command_Text.clear() self.conversion_Button_Text_var.set(WAIT_PROCESSING) self.conversion_Button.configure(state=tk.DISABLED) + self.command_Text.clear() def process_get_baseText(self, total_files, file_num): """Create the base text for the command widget""" @@ -4323,7 +4444,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.conversion_Button_Text_var.set(START_PROCESSING) self.conversion_Button.configure(state=tk.NORMAL) self.progress_bar_main_var.set(0) - + if error: error_message_box_text = f'{error_dialouge(error)}{ERROR_OCCURED[1]}' confirm = tk.messagebox.askyesno(parent=root, @@ -4564,7 +4685,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): audio_file_base = f"{file_num}_{os.path.splitext(os.path.basename(audio_file))[0]}" audio_file_base = audio_file_base if not self.is_testing_audio_var.get() or is_ensemble else f"{round(time.time())}_{audio_file_base}" audio_file_base = audio_file_base if not is_ensemble else f"{audio_file_base}_{current_model.model_basename}" - audio_file_base = audio_file_base if not self.is_add_model_name_var.get() else f"{audio_file_base}_{current_model.model_basename}" + if not is_ensemble: + audio_file_base = audio_file_base if not self.is_add_model_name_var.get() else f"{audio_file_base}_{current_model.model_basename}" if self.is_create_model_folder_var.get() and not is_ensemble: export_path = os.path.join(Path(self.export_path_var.get()), current_model.model_basename, os.path.splitext(os.path.basename(audio_file))[0]) @@ -4583,12 +4705,14 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): 'list_all_models': self.all_models, 'is_ensemble_master': is_ensemble, 'is_4_stem_ensemble': True if self.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE and is_ensemble else False} + if current_model.process_method == VR_ARCH_TYPE: seperator = SeperateVR(current_model, process_data) if current_model.process_method == MDX_ARCH_TYPE: seperator = SeperateMDX(current_model, process_data) if current_model.process_method == DEMUCS_ARCH_TYPE: seperator = SeperateDemucs(current_model, process_data) + seperator.seperate() if is_ensemble: @@ -4630,7 +4754,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): playsound(COMPLETE_CHIME) if self.is_task_complete_var.get() else None self.process_end() - + except Exception as e: self.error_log_var.set("{}{}".format(error_text(self.chosen_process_method_var.get(), e), self.get_settings_list())) self.command_Text.write(f'\n\n{PROCESS_FAILED}') @@ -4658,7 +4782,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): for key, value in DEFAULT_DATA.items(): if not key in data.keys(): data = {**data, **{key:value}} - + data['batch_size'] = DEF_OPT + ## ADD_BUTTON self.chosen_process_method_var = tk.StringVar(value=data['chosen_process_method']) @@ -4691,6 +4816,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.chunks_demucs_var = tk.StringVar(value=data['chunks_demucs']) self.margin_demucs_var = tk.StringVar(value=data['margin_demucs']) self.is_chunk_demucs_var = tk.BooleanVar(value=data['is_chunk_demucs']) + self.is_chunk_mdxnet_var = tk.BooleanVar(value=data['is_chunk_mdxnet']) self.is_primary_stem_only_Demucs_var = tk.BooleanVar(value=data['is_primary_stem_only_Demucs']) self.is_secondary_stem_only_Demucs_var = tk.BooleanVar(value=data['is_secondary_stem_only_Demucs']) self.is_split_mode_var = 
tk.BooleanVar(value=data['is_split_mode']) @@ -4715,6 +4841,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.compensate_var = tk.StringVar(value=data['compensate']) self.is_denoise_var = tk.BooleanVar(value=data['is_denoise']) self.is_invert_spec_var = tk.BooleanVar(value=data['is_invert_spec']) + self.is_mixer_mode_var = tk.BooleanVar(value=data['is_mixer_mode']) + self.mdx_batch_size_var = tk.StringVar(value=data['mdx_batch_size']) self.mdx_voc_inst_secondary_model_var = tk.StringVar(value=data['mdx_voc_inst_secondary_model']) self.mdx_other_secondary_model_var = tk.StringVar(value=data['mdx_other_secondary_model']) self.mdx_bass_secondary_model_var = tk.StringVar(value=data['mdx_bass_secondary_model']) @@ -4765,8 +4893,11 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): for key, value in DEFAULT_DATA.items(): if not key in loaded_setting.keys(): loaded_setting = {**loaded_setting, **{key:value}} + loaded_setting['batch_size'] = DEF_OPT - if not process_method or process_method == VR_ARCH_PM: + is_ensemble = True if process_method == ENSEMBLE_MODE else False + + if not process_method or process_method == VR_ARCH_PM or is_ensemble: self.vr_model_var.set(loaded_setting['vr_model']) self.aggression_setting_var.set(loaded_setting['aggression_setting']) self.window_size_var.set(loaded_setting['window_size']) @@ -4787,7 +4918,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.vr_bass_secondary_model_scale_var.set(loaded_setting['vr_bass_secondary_model_scale']) self.vr_drums_secondary_model_scale_var.set(loaded_setting['vr_drums_secondary_model_scale']) - if not process_method or process_method == DEMUCS_ARCH_TYPE: + if not process_method or process_method == DEMUCS_ARCH_TYPE or is_ensemble: self.demucs_model_var.set(loaded_setting['demucs_model']) self.segment_var.set(loaded_setting['segment']) self.overlap_var.set(loaded_setting['overlap']) @@ -4795,6 +4926,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.chunks_demucs_var.set(loaded_setting['chunks_demucs']) self.margin_demucs_var.set(loaded_setting['margin_demucs']) self.is_chunk_demucs_var.set(loaded_setting['is_chunk_demucs']) + self.is_chunk_mdxnet_var.set(loaded_setting['is_chunk_mdxnet']) self.is_primary_stem_only_Demucs_var.set(loaded_setting['is_primary_stem_only_Demucs']) self.is_secondary_stem_only_Demucs_var.set(loaded_setting['is_secondary_stem_only_Demucs']) self.is_split_mode_var.set(loaded_setting['is_split_mode']) @@ -4814,13 +4946,15 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.is_demucs_pre_proc_model_activate_var.set(data['is_demucs_pre_proc_model_activate']) self.is_demucs_pre_proc_model_inst_mix_var.set(data['is_demucs_pre_proc_model_inst_mix']) - if not process_method or process_method == MDX_ARCH_TYPE: + if not process_method or process_method == MDX_ARCH_TYPE or is_ensemble: self.mdx_net_model_var.set(loaded_setting['mdx_net_model']) self.chunks_var.set(loaded_setting['chunks']) self.margin_var.set(loaded_setting['margin']) self.compensate_var.set(loaded_setting['compensate']) self.is_denoise_var.set(loaded_setting['is_denoise']) self.is_invert_spec_var.set(loaded_setting['is_invert_spec']) + self.is_mixer_mode_var.set(loaded_setting['is_mixer_mode']) + self.mdx_batch_size_var.set(loaded_setting['mdx_batch_size']) self.mdx_voc_inst_secondary_model_var.set(loaded_setting['mdx_voc_inst_secondary_model']) self.mdx_other_secondary_model_var.set(loaded_setting['mdx_other_secondary_model']) 
self.mdx_bass_secondary_model_var.set(loaded_setting['mdx_bass_secondary_model']) @@ -4831,7 +4965,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): self.mdx_bass_secondary_model_scale_var.set(loaded_setting['mdx_bass_secondary_model_scale']) self.mdx_drums_secondary_model_scale_var.set(loaded_setting['mdx_drums_secondary_model_scale']) - if not process_method: + if not process_method or is_ensemble: self.is_save_all_outputs_ensemble_var.set(loaded_setting['is_save_all_outputs_ensemble']) self.is_append_ensemble_name_var.set(loaded_setting['is_append_ensemble_name']) self.chosen_audio_tool_var.set(loaded_setting['chosen_audio_tool']) @@ -4889,6 +5023,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): 'chunks_demucs': self.chunks_demucs_var.get(), 'margin_demucs': self.margin_demucs_var.get(), 'is_chunk_demucs': self.is_chunk_demucs_var.get(), + 'is_chunk_mdxnet': self.is_chunk_mdxnet_var.get(), 'is_primary_stem_only_Demucs': self.is_primary_stem_only_Demucs_var.get(), 'is_secondary_stem_only_Demucs': self.is_secondary_stem_only_Demucs_var.get(), 'is_split_mode': self.is_split_mode_var.get(), @@ -4910,7 +5045,9 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): 'margin': self.margin_var.get(), 'compensate': self.compensate_var.get(), 'is_denoise': self.is_denoise_var.get(), - 'is_invert_spec': self.is_invert_spec_var.get(), + 'is_invert_spec': self.is_invert_spec_var.get(), + 'is_mixer_mode': self.is_mixer_mode_var.get(), + 'mdx_batch_size':self.mdx_batch_size_var.get(), 'mdx_voc_inst_secondary_model': self.mdx_voc_inst_secondary_model_var.get(), 'mdx_other_secondary_model': self.mdx_other_secondary_model_var.get(), 'mdx_bass_secondary_model': self.mdx_bass_secondary_model_var.get(), @@ -4983,7 +5120,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk): settings_list = '\n'.join(''.join(f"{key}: {value}") for key, value in settings_dict.items() if not key == 'user_code') return f"\nFull Application Settings:\n\n{settings_list}" - + def secondary_stem(stem): """Determines secondary stem""" diff --git a/separate.py b/separate.py index fb39f4e..23e8a12 100644 --- a/separate.py +++ b/separate.py @@ -12,6 +12,7 @@ from lib_v5.vr_network import nets_new #from lib_v5.vr_network.model_param_init import ModelParameters from pathlib import Path from gui_data.constants import * +from gui_data.error_handling import * import audioread import gzip import librosa @@ -23,6 +24,8 @@ import torch import warnings import pydub import soundfile as sf +import traceback +import lib_v5.mdxnet as MdxnetSet if TYPE_CHECKING: from UVR import ModelData @@ -46,7 +49,10 @@ class SeperateAttributes: self.is_4_stem_ensemble = process_data['is_4_stem_ensemble'] self.list_all_models = process_data['list_all_models'] self.process_iteration = process_data['process_iteration'] + self.mixer_path = model_data.mixer_path self.model_samplerate = model_data.model_samplerate + self.model_capacity = model_data.model_capacity + self.is_vr_51_model = model_data.is_vr_51_model self.is_pre_proc_model = model_data.is_pre_proc_model self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True @@ -67,6 +73,7 @@ class SeperateAttributes: self.primary_stem = model_data.primary_stem # self.secondary_stem = model_data.secondary_stem # self.is_invert_spec = model_data.is_invert_spec # + self.is_mixer_mode = 
model_data.is_mixer_mode # self.secondary_model_scale = model_data.secondary_model_scale # self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix # self.primary_source_map = {} @@ -94,21 +101,24 @@ class SeperateAttributes: self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False if model_data.process_method == MDX_ARCH_TYPE: + self.is_mdx_ckpt = model_data.is_mdx_ckpt self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename) self.is_denoise = model_data.is_denoise + self.mdx_batch_size = model_data.mdx_batch_size self.compensate = model_data.compensate self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set self.n_fft = model_data.mdx_n_fft_scale_set self.chunks = model_data.chunks self.margin = model_data.margin - self.hop = 1024 - self.n_bins = self.n_fft//2+1 - self.chunk_size = self.hop * (self.dim_t-1) - self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(cpu) + self.adjust = 1 self.dim_c = 4 - out_c = self.dim_c - self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(cpu) - + self.hop = 1024 + + if self.is_gpu_conversion >= 0 and torch.cuda.is_available(): + self.device, self.run_type = torch.device('cuda:0'), ['CUDAExecutionProvider'] + else: + self.device, self.run_type = torch.device('cpu'), ['CPUExecutionProvider'] + if model_data.process_method == DEMUCS_ARCH_TYPE: self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None self.secondary_model_4_stem = model_data.secondary_model_4_stem @@ -152,15 +162,14 @@ class SeperateAttributes: self.is_post_process = model_data.is_post_process self.is_gpu_conversion = model_data.is_gpu_conversion self.batch_size = model_data.batch_size - self.crop_size = model_data.crop_size self.window_size = model_data.window_size self.input_high_end_h = None self.post_process_threshold = model_data.post_process_threshold self.aggressiveness = {'value': model_data.aggression_setting, 'split_bin': self.mp.param['band'][1]['crop_stop'], 'aggr_correction': self.mp.param.get('aggr_correction')} - - def start_inference(self): + + def start_inference_console_write(self): if self.is_secondary_model and not self.is_pre_proc_model: self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename)) @@ -168,7 +177,7 @@ class SeperateAttributes: if self.is_pre_proc_model: self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename)) - def running_inference(self, is_no_write=False): + def running_inference_console_write(self, is_no_write=False): self.write_to_console(DONE, base_text='') if not is_no_write else None self.set_progress_bar(0.05) if not is_no_write else None @@ -180,6 +189,15 @@ class SeperateAttributes: else: self.write_to_console(INFERENCE_STEP_1) + def running_inference_progress_bar(self, length, is_match_mix=False): + if not is_match_mix: + self.progress_value += 1 + + if (0.8/length*self.progress_value) >= 0.8: + length = self.progress_value + 1 + + self.set_progress_bar(0.1, (0.8/length*self.progress_value)) + def load_cached_sources(self, is_4_stem_demucs=False): if self.is_secondary_model and not self.is_pre_proc_model: @@ -222,31 +240,51 @@ class SeperateAttributes: self.write_to_console(DONE, base_text='') self.set_progress_bar(0.95) + def run_mixer(self, mix, sources): + try: + if self.is_mixer_mode and len(sources) == 4: + mixer = 
MdxnetSet.Mixer(self.device, self.mixer_path).eval() + with torch.no_grad(): + mix = torch.tensor(mix, dtype=torch.float32) + sources_ = torch.tensor(sources).detach() + x = torch.cat([sources_, mix.unsqueeze(0)], 0) + sources_ = mixer(x) + final_source = np.array(sources_) + else: + final_source = sources + except Exception as e: + error_name = f'{type(e).__name__}' + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + print('Mixer Failed: ', message) + final_source = sources + + return final_source + class SeperateMDX(SeperateAttributes): def seperate(self): - samplerate = 44100 - + if self.primary_model_name == self.model_basename and self.primary_sources: self.primary_source, self.secondary_source = self.load_cached_sources() else: - self.start_inference() + self.start_inference_console_write() - if self.is_gpu_conversion >= 0: - self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - run_type = ['CUDAExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider'] + if self.is_mdx_ckpt: + model_params = torch.load(self.model_path, map_location=lambda storage, loc: storage)['hyper_parameters'] + self.dim_c, self.hop = model_params['dim_c'], model_params['hop_length'] + separator = MdxnetSet.ConvTDFNet(**model_params) + self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval() else: - self.device = torch.device('cpu') - run_type = ['CPUExecutionProvider'] + ort_ = ort.InferenceSession(self.model_path, providers=self.run_type) + self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0] - self.onnx_model = ort.InferenceSession(self.model_path, providers=run_type) - - self.running_inference() + self.initialize_model_settings() + self.running_inference_console_write() mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut) - - source = self.demix_base(mix) + source = self.demix_base(mix, is_ckpt=self.is_mdx_ckpt)[0] self.write_to_console(DONE, base_text='') if self.is_secondary_model_activated: @@ -257,7 +295,7 @@ class SeperateMDX(SeperateAttributes): self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') if not isinstance(self.primary_source, np.ndarray): - self.primary_source = spec_utils.normalize(source[0], self.is_normalization).T + self.primary_source = spec_utils.normalize(source, self.is_normalization).T self.primary_source_map = {self.primary_stem: self.primary_source} self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary) @@ -266,7 +304,7 @@ class SeperateMDX(SeperateAttributes): secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav') if not isinstance(self.secondary_source, np.ndarray): raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix - self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source[0]*self.compensate, raw_mix, self.is_normalization) + self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source*self.compensate, raw_mix, self.is_normalization) if self.is_invert_spec: self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source) @@ -277,7 +315,6 @@ class 
SeperateMDX(SeperateAttributes): self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary) torch.cuda.empty_cache() - secondary_sources = {**self.primary_source_map, **self.secondary_source_map} self.cache_source(secondary_sources) @@ -285,53 +322,73 @@ class SeperateMDX(SeperateAttributes): if self.is_secondary_model: return secondary_sources - def demix_base(self, mix, is_match_mix=False): - chunked_sources = [] + def initialize_model_settings(self): + self.n_bins = self.n_fft//2+1 + self.trim = self.n_fft//2 + self.chunk_size = self.hop * (self.dim_t-1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device) + self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device) + self.gen_size = self.chunk_size-2*self.trim - for slice in mix: - self.progress_value += 1 - self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if not is_match_mix else None - cmix = mix[slice] - sources = [] + def initialize_mix(self, mix, is_ckpt=False): + if is_ckpt: + pad = self.gen_size + self.trim - ((mix.shape[-1]) % self.gen_size) + mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'),mix, np.zeros((2, pad), dtype='float32')), 1) + num_chunks = mixture.shape[-1] // self.gen_size + mix_waves = [mixture[:, i * self.gen_size: i * self.gen_size + self.chunk_size] for i in range(num_chunks)] + else: mix_waves = [] - n_sample = cmix.shape[1] - trim = self.n_fft//2 - gen_size = self.chunk_size-2*trim - pad = gen_size - n_sample%gen_size - mix_p = np.concatenate((np.zeros((2,trim)), cmix, np.zeros((2,pad)), np.zeros((2,trim))), 1) + n_sample = mix.shape[1] + pad = self.gen_size - n_sample%self.gen_size + mix_p = np.concatenate((np.zeros((2,self.trim)), mix, np.zeros((2,pad)), np.zeros((2,self.trim))), 1) i = 0 while i < n_sample + pad: waves = np.array(mix_p[:, i:i+self.chunk_size]) mix_waves.append(waves) - i += gen_size - mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu) - with torch.no_grad(): - _ort = self.onnx_model if not is_match_mix else None - adjust = 1 - spek = self.stft(mix_waves)*adjust + i += self.gen_size - if not is_match_mix: - if self.is_denoise: - spec_pred = -_ort.run(None, {'input': -spek.cpu().numpy()})[0]*0.5+_ort.run(None, {'input': spek.cpu().numpy()})[0]*0.5 - else: - spec_pred = _ort.run(None, {'input': spek.cpu().numpy()})[0] - else: - spec_pred = spek.cpu().numpy() + mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device) - tar_waves = self.istft(torch.tensor(spec_pred))#.cpu() - tar_signal = tar_waves[:,:,trim:-trim].transpose(0,1).reshape(2, -1).numpy()[:, :-pad] + return mix_waves, pad + + def demix_base(self, mix, is_ckpt=False, is_match_mix=False): + chunked_sources = [] + for slice in mix: + sources = [] + tar_waves_ = [] + mix_p = mix[slice] + mix_waves, pad = self.initialize_mix(mix_p, is_ckpt=is_ckpt) + mix_waves = mix_waves.split(self.mdx_batch_size) + pad = mix_p.shape[-1] if is_ckpt else -pad + with torch.no_grad(): + for mix_wave in mix_waves: + self.running_inference_progress_bar(len(mix)*len(mix_waves), is_match_mix=is_match_mix) + tar_waves = self.run_model(mix_wave, is_ckpt=is_ckpt, is_match_mix=is_match_mix) + tar_waves_.append(tar_waves) + tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim] if is_ckpt else tar_waves_ + tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :pad] start = 0 if slice == 0 else self.margin - end = None if slice == list(mix.keys())[::-1][0] else -self.margin - 
@@ -343,11 +400,10 @@ class SeperateMDX(SeperateAttributes):

     def istft(self, x, freq_pad=None):
         freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
         x = torch.cat([x, freq_pad], -2)
-        c = 2
-        x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
+        x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
         x = x.permute([0,2,3,1])
         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
-        return x.reshape([-1,c,self.chunk_size])
+        return x.reshape([-1,2,self.chunk_size])
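The istft() tweak above only hard-codes the stereo channel count; the transform pair itself is plain torch.stft/torch.istft with a periodic=False Hann window. Below is a round-trip sketch with assumed sizes, written against the complex-tensor API, which differs slightly from the real/imaginary channel layout the class keeps internally.

import torch

n_fft, hop, dim_t = 6144, 1024, 256                  # hypothetical settings
chunk_size = hop * (dim_t - 1)
window = torch.hann_window(window_length=n_fft, periodic=False)

x = torch.randn(4, chunk_size)                       # e.g. two stereo chunks flattened to four mono rows
spec = torch.stft(x, n_fft, hop_length=hop, window=window, center=True, return_complex=True)
spec = torch.view_as_real(spec).permute(0, 3, 1, 2)  # [batch, 2 (re/im), n_bins, dim_t]
assert spec.shape == (4, 2, n_fft // 2 + 1, dim_t)

recon = torch.istft(torch.view_as_complex(spec.permute(0, 2, 3, 1).contiguous()),
                    n_fft, hop_length=hop, window=window, center=True)
assert recon.shape == (4, chunk_size)                # center=True gives back exactly (dim_t - 1) * hop samples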

 class SeperateDemucs(SeperateAttributes):

@@ -371,7 +427,7 @@ class SeperateDemucs(SeperateAttributes):
                 source = self.primary_sources
                 self.load_cached_sources(is_4_stem_demucs=True)
         else:
-            self.start_inference()
+            self.start_inference_console_write()

             if self.is_gpu_conversion >= 0:
                 self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@@ -405,22 +461,25 @@ class SeperateDemucs(SeperateAttributes):
                 mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
                 inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs)
                 self.process_iteration()
-                self.running_inference(is_no_write=is_no_write)
+                self.running_inference_console_write(is_no_write=is_no_write)
                 inst_source = self.demix_demucs(inst_mix)
+                inst_source = self.run_mixer(inst_raw_mix, inst_source)
                 self.process_iteration()

-            self.running_inference(is_no_write=is_no_write) if not self.pre_proc_model else None
+            self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
             mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs)

             if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
                 source = self.primary_sources
             else:
                 source = self.demix_demucs(mix)
+                source = self.run_mixer(raw_mix, source)

             self.write_to_console(DONE, base_text='')
             del self.demucs
-
+            torch.cuda.empty_cache()
+
         if isinstance(inst_source, np.ndarray):
             source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
             inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape
@@ -431,6 +490,7 @@ class SeperateDemucs(SeperateAttributes):
             self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
         else:
             self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER
+
         if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model:
             is_no_piano_guitar = True
             six_stem_other_source = list(source)
@@ -445,7 +505,6 @@ class SeperateDemucs(SeperateAttributes):
             self.cache_source(source)

         for stem_name, stem_value in self.demucs_source_map.items():
-
             if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
                 if self.secondary_model_4_stem[stem_value]:
                     model_scale = self.secondary_model_4_stem_scale[stem_value]
@@ -520,9 +579,7 @@ class SeperateDemucs(SeperateAttributes):
             if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
                 secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True)
-
-        torch.cuda.empty_cache()
-
+
         secondary_sources = {**self.primary_source_map, **self.secondary_source_map}

         self.cache_source(secondary_sources)
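The source-map bookkeeping above only decides which stem-name-to-index table applies, based on how many sources the Demucs model returned. A hypothetical illustration follows; the real DEMUCS_*_SOURCE_MAPPER constants (and the condition that selects the 2-stem case) are defined elsewhere in UVR and are not shown in this patch, so the names, values, and ordering here are assumptions.

# Hypothetical stand-ins for UVR's DEMUCS_*_SOURCE_MAPPER constants.
TWO_STEM_MAPPER = {"Instrumental": 0, "Vocals": 1}
FOUR_STEM_MAPPER = {"Drums": 0, "Bass": 1, "Other": 2, "Vocals": 3}
SIX_STEM_MAPPER = {"Drums": 0, "Bass": 1, "Other": 2, "Vocals": 3, "Guitar": 4, "Piano": 5}

def pick_source_map(source):
    """Choose a stem-name -> index map from the number of returned sources."""
    if len(source) == 2:
        return TWO_STEM_MAPPER
    return SIX_STEM_MAPPER if len(source) == 6 else FOUR_STEM_MAPPER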
@@ -583,11 +640,10 @@ class SeperateDemucs(SeperateAttributes):

 class SeperateVR(SeperateAttributes):

     def seperate(self):
-
         if self.primary_model_name == self.model_basename and self.primary_sources:
             self.primary_source, self.secondary_source = self.load_cached_sources()
         else:
-            self.start_inference()
+            self.start_inference_console_write()

             if self.is_gpu_conversion >= 0:
                 if OPERATING_SYSTEM == 'Darwin':
                     device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
@@ -595,32 +651,27 @@ class SeperateVR(SeperateAttributes):
                     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             else:
                 device = torch.device('cpu')
-
+
             nn_arch_sizes = [
                 31191, # default
-                33966, 56817, 218409, 123821, 123812, 129605, 537238, 537227]
+                33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]
             vr_5_1_models = [56817, 218409]
-
             model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
-            nn_architecture = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
+            nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))

-            if nn_architecture in vr_5_1_models:
-                model = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_architecture)
-                inference = self.inference_vr_new
+            if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
+                self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
             else:
-                model = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_architecture)
-                inference = self.inference_vr
+                self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)
+
+            self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
+            self.model_run.to(device)
-            model.load_state_dict(torch.load(self.model_path, map_location=device))
-            model.to(device)
-
-            self.running_inference()
-
-            y_spec, v_spec = inference(self.loading_mix(), device, model, self.aggressiveness)
+            self.running_inference_console_write()
+
+            y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)

             self.write_to_console(DONE, base_text='')
-            del model
-
             if self.is_secondary_model_activated:
                 if self.secondary_model:
                     self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
@@ -649,9 +700,8 @@ class SeperateVR(SeperateAttributes):
                 self.secondary_source_map = {self.secondary_stem: self.secondary_source}

             self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary)
-
+        torch.cuda.empty_cache()
-
         secondary_sources = {**self.primary_source_map, **self.secondary_source_map}

         self.cache_source(secondary_sources)
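The architecture guess above is a nearest-neighbour lookup on the checkpoint's size in KiB; pulled out of the class it amounts to the following sketch, with the size list copied from the hunk above and guess_arch_size being an illustrative name rather than a UVR function.

import math
import os

# Known VR checkpoint sizes in KiB, copied from the hunk above; 56817 and 218409 are the 5.1-style nets.
NN_ARCH_SIZES = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]

def guess_arch_size(model_path):
    """Return the known architecture size closest to the checkpoint's on-disk size."""
    model_size = math.ceil(os.stat(model_path).st_size / 1024)
    return min(NN_ARCH_SIZES, key=lambda size: abs(size - model_size))

The patch keeps this size heuristic but additionally honors self.is_vr_51_model, so a model can take the nets_new.CascadedNet path regardless of where its file size lands.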
@@ -696,80 +746,20 @@ class SeperateVR(SeperateAttributes):

         return X_spec

-    def inference_vr(self, X_spec, device, model, aggressiveness):
-
-        def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness):
-            model.eval()
-
-            total_iterations = sum([n_window]) if not self.is_tta else sum([n_window])*2
-
-            with torch.no_grad():
-                preds = []
-
-                for i in range(n_window):
-                    self.progress_value +=1
-                    self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
-                    start = i * roi_size
-                    X_mag_window = X_mag_pad[None, :, :, start:start + self.window_size]
-                    X_mag_window = torch.from_numpy(X_mag_window).to(device)
-                    pred = model.predict(X_mag_window, aggressiveness)
-                    pred = pred.detach().cpu().numpy()
-                    preds.append(pred[0])
-
-                pred = np.concatenate(preds, axis=2)
-            return pred
-
-        X_mag, X_phase = spec_utils.preprocess(X_spec)
-        coef = X_mag.max()
-        X_mag_pre = X_mag / coef
-        n_frame = X_mag_pre.shape[2]
-        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, model.offset)
-        n_window = int(np.ceil(n_frame / roi_size))
-        X_mag_pad = np.pad(
-            X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
-        pred = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
-        pred = pred[:, :, :n_frame]
-
-        if self.is_tta:
-            pad_l += roi_size // 2
-            pad_r += roi_size // 2
-            n_window += 1
-            X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
-            pred_tta = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
-            pred_tta = pred_tta[:, :, roi_size // 2:]
-            pred_tta = pred_tta[:, :, :n_frame]
-            pred, X_mag, X_phase = (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
-        else:
-            pred, X_mag, X_phase = pred * coef, X_mag, np.exp(1.j * X_phase)
-
-        if self.is_post_process:
-            pred_inv = np.clip(X_mag - pred, 0, np.inf)
-            pred = spec_utils.mask_silence(pred, pred_inv, thres=self.post_process_threshold)
-
-        y_spec = pred * X_phase
-        v_spec = X_spec - y_spec
-
-        return y_spec, v_spec
-
-    def inference_vr_new(self, X_spec, device, model, aggressiveness):
-
+    def inference_vr(self, X_spec, device, aggressiveness):
         def _execute(X_mag_pad, roi_size):
-
             X_dataset = []
-            patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
+            patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
             total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2
-
             for i in range(patches):
                 start = i * roi_size
-                X_mag_crop = X_mag_pad[:, :, start:start + self.crop_size]
-                X_dataset.append(X_mag_crop)
+                X_mag_window = X_mag_pad[:, :, start:start + self.window_size]
+                X_dataset.append(X_mag_window)

             X_dataset = np.asarray(X_dataset)
-            model.eval()
-
+            self.model_run.eval()
             with torch.no_grad():
                 mask = []
                 # To reduce the overhead, dataloader is not used.
                 for i in range(0, patches, self.batch_size):
                     self.progress_value += 1
                     if self.progress_value >= total_iterations:
@@ -777,33 +767,37 @@ class SeperateVR(SeperateAttributes):
                     self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
                     X_batch = X_dataset[i: i + self.batch_size]
                     X_batch = torch.from_numpy(X_batch).to(device)
-                    pred = model.predict_mask(X_batch)
+                    pred = self.model_run.predict_mask(X_batch)
+                    if not pred.size()[3] > 0:
+                        raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
                     pred = pred.detach().cpu().numpy()
                     pred = np.concatenate(pred, axis=2)
                     mask.append(pred)
-
+                if len(mask) == 0:
+                    raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
+
                 mask = np.concatenate(mask, axis=2)
-
             return mask
-
-        def postprocess(mask, X_mag, X_phase, aggressiveness):
+
+        def postprocess(mask, X_mag, X_phase):
-            if self.primary_stem == VOCAL_STEM:
-                mask = (1.0 - spec_utils.adjust_aggr(mask, True, aggressiveness))
-            else:
-                mask = spec_utils.adjust_aggr(mask, False, aggressiveness)
+            is_non_accom_stem = False
+            for stem in NON_ACCOM_STEMS:
+                if stem == self.primary_stem:
+                    is_non_accom_stem = True
+
+            mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)

             if self.is_post_process:
-                mask = spec_utils.merge_artifacts(mask)
+                mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)

             y_spec = mask * X_mag * np.exp(1.j * X_phase)
             v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)

             return y_spec, v_spec
-
         X_mag, X_phase = spec_utils.preprocess(X_spec)
         n_frame = X_mag.shape[2]
-        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.crop_size, model.offset)
+        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
         X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
         X_mag_pad /= X_mag_pad.max()
         mask = _execute(X_mag_pad, roi_size)
@@ -819,7 +813,7 @@ class SeperateVR(SeperateAttributes):
         else:
             mask = mask[:, :, :n_frame]

-        y_spec, v_spec = postprocess(mask, X_mag, X_phase, aggressiveness)
+        y_spec, v_spec = postprocess(mask, X_mag, X_phase)

         return y_spec, v_spec

@@ -889,6 +883,7 @@ def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=Fa
     margin = margin_set
     chunk_size = chunk_set*44100
     assert not margin == 0, 'margin cannot be zero!'
+
     if margin > chunk_size:
         margin = chunk_size
     if chunk_set == 0 or samples < chunk_size:
@@ -941,4 +936,4 @@ def save_format(audio_path, save_format, mp3_bit_set):
             try:
                 os.remove(audio_path)
             except Exception as e:
-                print(e)
+                print(e)
\ No newline at end of file
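After this patch, inference_vr() predicts masks for batches of sliding spectrogram windows instead of one window at a time. Below is a hedged sketch of just that batching step; window_size, roi_size, offset, and the model call are stand-ins, and the progress reporting and the new window-size error checks are omitted.

import numpy as np
import torch

def batched_windows(x_mag_pad, window_size, roi_size, offset, batch_size):
    """Cut a padded magnitude spectrogram (channels, bins, frames) into overlapping
    windows and yield them as torch batches, mirroring the loop in _execute."""
    # Assumes the spectrogram was padded (as spec_utils.make_padding does in UVR)
    # so that every window of window_size frames fits inside the array.
    patches = (x_mag_pad.shape[2] - 2 * offset) // roi_size
    windows = np.asarray([x_mag_pad[:, :, i * roi_size:i * roi_size + window_size] for i in range(patches)])
    for i in range(0, patches, batch_size):
        yield torch.from_numpy(windows[i:i + batch_size])

# Usage sketch: collect model_run.predict_mask(batch) for each batch, then
# np.concatenate the predictions along the frame axis, as the hunk above does.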