From 3cf54b53f6b175d7e579d8a21062ff59aa133589 Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Mon, 9 Nov 2020 04:32:56 -0600 Subject: [PATCH] Add files via upload --- VocalRemover.py | 898 ++++++++++++++++++++++++++++++++++++++++++++++++ inference_v2.py | 489 ++++++++++++++++++++++++++ inference_v4.py | 525 ++++++++++++++++++++++++++++ 3 files changed, 1912 insertions(+) create mode 100644 VocalRemover.py create mode 100644 inference_v2.py create mode 100644 inference_v4.py diff --git a/VocalRemover.py b/VocalRemover.py new file mode 100644 index 0000000..37f67f2 --- /dev/null +++ b/VocalRemover.py @@ -0,0 +1,898 @@ +# GUI modules +import tkinter as tk +import tkinter.ttk as ttk +import tkinter.messagebox +import tkinter.filedialog +import tkinter.font +from datetime import datetime +# Images +from PIL import Image +from PIL import ImageTk +import pickle # Save Data +# Other Modules +import subprocess # Run python file +# Pathfinding +import pathlib +import sys +import os +from collections import defaultdict +# Used for live text displaying +import queue +import threading # Run the algorithm inside a thread + +import inference_v2 +import inference_v4 + +# Change the current working directory to the directory +# this file sits in +if getattr(sys, 'frozen', False): + # If the application is run as a bundle, the PyInstaller bootloader + # extends the sys module by a flag frozen=True and sets the app + # path into variable _MEIPASS'. + base_path = sys._MEIPASS +else: + base_path = os.path.dirname(os.path.abspath(__file__)) +os.chdir(base_path) # Change the current working directory to the base path + +instrumentalModels_dir = os.path.join(base_path, 'models') +stackedModels_dir = os.path.join(base_path, 'models') +logo_path = os.path.join(base_path, 'img', 'UVR-logo.png') +refresh_path = os.path.join(base_path, 'img', 'refresh.png') +DEFAULT_DATA = { + 'export_path': '', + 'gpu': False, + 'postprocess': False, + 'tta': False, + 'output_image': False, + 'sr': 44100, + 'hop_length': 1024, + 'window_size': 512, + 'n_fft': 2048, + 'stack': False, + 'stackPasses': 1, + 'stackOnly': False, + 'saveAllStacked': False, + 'modelFolder': False, + 'aiModel': 'v4', + + 'useModel': 'instrumental', + 'lastDir': None, +} + + +def open_image(path: str, size: tuple = None, keep_aspect: bool = True, rotate: int = 0) -> ImageTk.PhotoImage: + """ + Open the image on the path and apply given settings\n + Paramaters: + path(str): + Absolute path of the image + size(tuple): + first value - width + second value - height + keep_aspect(bool): + keep aspect ratio of image and resize + to maximum possible width and height + (maxima are given by size) + rotate(int): + clockwise rotation of image + Returns(ImageTk.PhotoImage): + Image of path + """ + img = Image.open(path).convert(mode='RGBA') + ratio = img.height/img.width + img = img.rotate(angle=-rotate) + if size is not None: + size = (int(size[0]), int(size[1])) + if keep_aspect: + img = img.resize((size[0], int(size[0] * ratio)), Image.ANTIALIAS) + else: + img = img.resize(size, Image.ANTIALIAS) + return ImageTk.PhotoImage(img) + + +def save_data(data): + """ + Saves given data as a .pkl (pickle) file + + Paramters: + data(dict): + Dictionary containing all the necessary data to save + """ + # Open data file, create it if it does not exist + with open('data.pkl', 'wb') as data_file: + pickle.dump(data, data_file) + + +def load_data() -> dict: + """ + Loads saved pkl file and returns the stored data + + Returns(dict): + 
Dictionary containing all the saved data + """ + try: + with open('data.pkl', 'rb') as data_file: # Open data file + data = pickle.load(data_file) + + return data + except (ValueError, FileNotFoundError): + # Data File is corrupted or not found so recreate it + save_data(data=DEFAULT_DATA) + + return load_data() + + +def get_model_values(model_name): + text = model_name.replace('.pth', '') + text_parts = text.split('_')[1:] + model_values = {} + + for text_part in text_parts: + if 'sr' in text_part: + text_part = text_part.replace('sr', '') + if text_part.isdecimal(): + try: + model_values['sr'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'hl' in text_part: + text_part = text_part.replace('hl', '') + if text_part.isdecimal(): + try: + model_values['hop_length'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'w' in text_part: + text_part = text_part.replace('w', '') + if text_part.isdecimal(): + try: + model_values['window_size'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'nf' in text_part: + text_part = text_part.replace('nf', '') + if text_part.isdecimal(): + try: + model_values['n_fft'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + + return model_values + + +class ThreadSafeConsole(tk.Text): + """ + Text Widget which is thread safe for tkinter + """ + + def __init__(self, master, **options): + tk.Text.__init__(self, master, **options) + self.queue = queue.Queue() + self.update_me() + + def write(self, line): + self.queue.put(line) + + def clear(self): + self.queue.put(None) + + def update_me(self): + self.configure(state=tk.NORMAL) + try: + while 1: + line = self.queue.get_nowait() + if line is None: + self.delete(1.0, tk.END) + else: + self.insert(tk.END, str(line)) + self.see(tk.END) + self.update_idletasks() + except queue.Empty: + pass + self.configure(state=tk.DISABLED) + self.after(100, self.update_me) + + +class MainWindow(tk.Tk): + # --Constants-- + # Layout + IMAGE_HEIGHT = 140 + FILEPATHS_HEIGHT = 90 + OPTIONS_HEIGHT = 240 + CONVERSIONBUTTON_HEIGHT = 35 + COMMAND_HEIGHT = 200 + PROGRESS_HEIGHT = 26 + PADDING = 10 + + COL1_ROWS = 8 + COL2_ROWS = 7 + COL3_ROWS = 5 + + def __init__(self): + # Run the __init__ method on the tk.Tk class + super().__init__() + # Calculate window height + height = self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + height += self.CONVERSIONBUTTON_HEIGHT + self.COMMAND_HEIGHT + self.PROGRESS_HEIGHT + height += self.PADDING * 5 # Padding + + # --Window Settings-- + self.title('Vocal Remover') + # Set Geometry and Center Window + self.geometry('{width}x{height}+{xpad}+{ypad}'.format( + width=590, + height=height, + xpad=int(self.winfo_screenwidth()/2 - 550/2), + ypad=int(self.winfo_screenheight()/2 - height/2 - 30))) + self.configure(bg='#000000') # Set background color to black + self.resizable(False, False) + self.update() + + # --Variables-- + self.logo_img = open_image(path=logo_path, + size=(self.winfo_width(), 9999)) + self.refresh_img = open_image(path=refresh_path, + size=(20, 20)) + self.instrumentalLabel_to_path = defaultdict(lambda: '') + self.stackedLabel_to_path = defaultdict(lambda: '') + self.lastInstrumentalModels = [] + self.lastStackedModels = [] + # -Tkinter Value Holders- + data = load_data() + # Paths + self.exportPath_var = tk.StringVar(value=data['export_path']) + self.inputPaths = [] + # Processing Options + 
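Aside (not part of the patch): the sr/hl/w/nf tokens that get_model_values above, and update_constants in both inference scripts, pull out of a model filename follow a simple underscore-separated convention. A minimal standalone sketch of that decoding, using a hypothetical filename purely for illustration:

import re

def decode_model_name(model_name: str) -> dict:
    """Map filename tokens such as sr44100 / hl1024 / w512 / nf2048 to settings."""
    keys = {'sr': 'sr', 'hl': 'hop_length', 'w': 'window_size', 'nf': 'n_fft'}
    values = {}
    for token in model_name.replace('.pth', '').split('_')[1:]:
        match = re.fullmatch(r'(sr|hl|nf|w)(\d+)', token)
        if match:
            values[keys[match.group(1)]] = int(match.group(2))
    return values

# decode_model_name('Model_sr44100_hl1024_w512_nf2048.pth')
# -> {'sr': 44100, 'hop_length': 1024, 'window_size': 512, 'n_fft': 2048}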
self.gpuConversion_var = tk.BooleanVar(value=data['gpu']) + self.postprocessing_var = tk.BooleanVar(value=data['postprocess']) + self.tta_var = tk.BooleanVar(value=data['tta']) + self.outputImage_var = tk.BooleanVar(value=data['output_image']) + # Models + self.instrumentalModel_var = tk.StringVar(value='') + self.stackedModel_var = tk.StringVar(value='') + # Stacked Options + self.stack_var = tk.BooleanVar(value=data['stack']) + self.stackLoops_var = tk.StringVar(value=data['stackPasses']) + self.stackOnly_var = tk.BooleanVar(value=data['stackOnly']) + self.saveAllStacked_var = tk.BooleanVar(value=data['saveAllStacked']) + self.modelFolder_var = tk.BooleanVar(value=data['modelFolder']) + # Constants + self.srValue_var = tk.StringVar(value=data['sr']) + self.hopValue_var = tk.StringVar(value=data['hop_length']) + self.winSize_var = tk.StringVar(value=data['window_size']) + self.nfft_var = tk.StringVar(value=data['n_fft']) + # AI model + self.aiModel_var = tk.StringVar(value=data['aiModel']) + self.last_aiModel = self.aiModel_var.get() + # Other + self.lastDir = data['lastDir'] # nopep8 + self.progress_var = tk.IntVar(value=0) + # Font + self.font = tk.font.Font(family='Helvetica', size=9, weight='bold') + # --Widgets-- + self.create_widgets() + self.configure_widgets() + self.place_widgets() + + self.update_available_models() + self.update_states() + self.update_loop() + + # -Widget Methods- + def create_widgets(self): + """Create window widgets""" + self.title_Label = tk.Label(master=self, bg='black', + image=self.logo_img, compound=tk.TOP) + self.filePaths_Frame = tk.Frame(master=self, bg='black') + self.fill_filePaths_Frame() + + self.options_Frame = tk.Frame(master=self, bg='black') + self.fill_options_Frame() + + self.conversion_Button = ttk.Button(master=self, + text='Start Conversion', + command=self.start_conversion) + self.refresh_Button = ttk.Button(master=self, + image=self.refresh_img, + command=self.restart) + + self.progressbar = ttk.Progressbar(master=self, + variable=self.progress_var) + + self.command_Text = ThreadSafeConsole(master=self, + background='#a0a0a0', + borderwidth=0,) + self.command_Text.write(f'COMMAND LINE [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]') # nopep8 + + def configure_widgets(self): + """Change widget styling and appearance""" + + ttk.Style().configure('TCheckbutton', background='black', + font=self.font, foreground='white') + ttk.Style().configure('TRadiobutton', background='black', + font=self.font, foreground='white') + ttk.Style().configure('T', font=self.font, foreground='white') + + def place_widgets(self): + """Place main widgets""" + self.title_Label.place(x=-2, y=-2) + self.filePaths_Frame.place(x=10, y=self.IMAGE_HEIGHT, width=-20, height=self.FILEPATHS_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + self.options_Frame.place(x=25, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.PADDING, width=-50, height=self.OPTIONS_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + self.conversion_Button.place(x=10, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.PADDING*2, width=-20 - 40, height=self.CONVERSIONBUTTON_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + self.refresh_Button.place(x=-10 - 35, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.PADDING*2, width=35, height=self.CONVERSIONBUTTON_HEIGHT, + relx=1, rely=0, relwidth=0, relheight=0) + self.command_Text.place(x=15, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.CONVERSIONBUTTON_HEIGHT + 
self.PADDING*3, width=-30, height=self.COMMAND_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + self.progressbar.place(x=25, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.CONVERSIONBUTTON_HEIGHT + self.COMMAND_HEIGHT + self.PADDING*4, width=-50, height=self.PROGRESS_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + + def fill_filePaths_Frame(self): + """Fill Frame with neccessary widgets""" + # -Create Widgets- + # Save To Option + self.filePaths_saveTo_Button = ttk.Button(master=self.filePaths_Frame, + text='Save to', + command=self.open_export_filedialog) + self.filePaths_saveTo_Entry = ttk.Entry(master=self.filePaths_Frame, + + textvariable=self.exportPath_var, + state=tk.DISABLED + ) + # Select Music Files Option + self.filePaths_musicFile_Button = ttk.Button(master=self.filePaths_Frame, + text='Select Your Audio File(s)', + command=self.open_file_filedialog) + self.filePaths_musicFile_Entry = ttk.Entry(master=self.filePaths_Frame, + text=self.inputPaths, + state=tk.DISABLED + ) + # -Place Widgets- + # Save To Option + self.filePaths_saveTo_Button.place(x=0, y=5, width=0, height=-10, + relx=0, rely=0, relwidth=0.3, relheight=0.5) + self.filePaths_saveTo_Entry.place(x=10, y=7, width=-20, height=-14, + relx=0.3, rely=0, relwidth=0.7, relheight=0.5) + # Select Music Files Option + self.filePaths_musicFile_Button.place(x=0, y=5, width=0, height=-10, + relx=0, rely=0.5, relwidth=0.4, relheight=0.5) + self.filePaths_musicFile_Entry.place(x=10, y=7, width=-20, height=-14, + relx=0.4, rely=0.5, relwidth=0.6, relheight=0.5) + + def fill_options_Frame(self): + """Fill Frame with neccessary widgets""" + # -Create Widgets- + # -Column 1- + # GPU Selection + self.options_gpu_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='GPU Conversion', + variable=self.gpuConversion_var, + ) + # Postprocessing + self.options_post_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Post-Process', + variable=self.postprocessing_var, + ) + # TTA + self.options_tta_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='TTA', + variable=self.tta_var, + ) + # Save Image + self.options_image_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Output Image', + variable=self.outputImage_var, + ) + # Stack Loops + self.options_stack_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Stack Passes', + variable=self.stack_var, + ) + self.options_stack_Entry = ttk.Entry(master=self.options_Frame, + textvariable=self.stackLoops_var,) + # Stack Only + self.options_stackOnly_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Stack Conversion Only', + variable=self.stackOnly_var, + ) + # Save All Stacked Outputs + self.options_saveStack_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Save All Stacked Outputs', + variable=self.saveAllStacked_var, + ) + self.options_modelFolder_Checkbutton = ttk.Checkbutton(master=self.options_Frame, + text='Model Test Mode', + variable=self.modelFolder_var, + ) + # -Column 2- + # SR + self.options_sr_Entry = ttk.Entry(master=self.options_Frame, + textvariable=self.srValue_var,) + self.options_sr_Label = tk.Label(master=self.options_Frame, + text='SR', anchor=tk.W, + background='#63605f', font=self.font, foreground='white', relief="sunken") + # HOP LENGTH + self.options_hop_Entry = ttk.Entry(master=self.options_Frame, + textvariable=self.hopValue_var,) + self.options_hop_Label = tk.Label(master=self.options_Frame, + text='HOP LENGTH', anchor=tk.W, + 
background='#63605f', font=self.font, foreground='white', relief="sunken") + # WINDOW SIZE + self.options_winSize_Entry = ttk.Entry(master=self.options_Frame, + textvariable=self.winSize_var,) + self.options_winSize_Label = tk.Label(master=self.options_Frame, + text='WINDOW SIZE', anchor=tk.W, + background='#63605f', font=self.font, foreground='white', relief="sunken") + # N_FFT + self.options_nfft_Entry = ttk.Entry(master=self.options_Frame, + textvariable=self.nfft_var,) + self.options_nfft_Label = tk.Label(master=self.options_Frame, + text='N_FFT', anchor=tk.W, + background='#63605f', font=self.font, foreground='white', relief="sunken") + # AI model + self.options_aiModel_Label = tk.Label(master=self.options_Frame, + text='Choose AI Engine', anchor=tk.CENTER, + background='#63605f', font=self.font, foreground='white', relief="sunken") + self.options_aiModel_Optionmenu = ttk.OptionMenu(self.options_Frame, + self.aiModel_var, + None, 'v2', 'v4',) + # "Save to", "Select Your Audio File(s)"", and "Start Conversion" Button Style + s = ttk.Style() + s.configure('TButton', background='blue', foreground='black', font=('Verdana', '9', 'bold'), relief="sunken") + + # -Column 3- + # Choose Instrumental Model + self.options_instrumentalModel_Label = tk.Label(master=self.options_Frame, + text='Choose Instrumental Model', + background='#a7a7a7', font=self.font, relief="ridge") + self.options_instrumentalModel_Optionmenu = ttk.OptionMenu(self.options_Frame, + self.instrumentalModel_var) + # Choose Stacked Model + self.options_stackedModel_Label = tk.Label(master=self.options_Frame, + text='Choose Stacked Model', + background='#a7a7a7', font=self.font, relief="ridge") + self.options_stackedModel_Optionmenu = ttk.OptionMenu(self.options_Frame, + self.stackedModel_var,) + self.options_model_Button = ttk.Button(master=self.options_Frame, + text='Add New Model(s)', + style="Bold.TButton", + command=self.open_newModel_filedialog) + # -Place Widgets- + # -Column 1- + self.options_gpu_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=0, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.options_post_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=1/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.options_tta_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.options_image_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + # Stacks + self.options_stack_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=4/self.COL1_ROWS, relwidth=1/3/4*3, relheight=1/self.COL1_ROWS) + self.options_stack_Entry.place(x=0, y=3, width=0, height=-6, + relx=1/3/4*2.4, rely=4/self.COL1_ROWS, relwidth=1/3/4*0.9, relheight=1/self.COL1_ROWS) + self.options_stackOnly_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=5/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.options_saveStack_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + # Model Folder + self.options_modelFolder_Checkbutton.place(x=0, y=0, width=0, height=0, + relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + # -Column 2- + # SR + self.options_sr_Label.place(x=5, y=4, width=5, height=-8, + relx=1/3, rely=0, relwidth=1/3/2, relheight=1/self.COL2_ROWS) + self.options_sr_Entry.place(x=15, y=4, width=5, height=-8, + relx=1/3 + 1/3/2, rely=0, relwidth=1/3/4, 
relheight=1/self.COL2_ROWS) + # HOP LENGTH + self.options_hop_Label.place(x=5, y=4, width=5, height=-8, + relx=1/3, rely=1/self.COL2_ROWS, relwidth=1/3/2, relheight=1/self.COL2_ROWS) + self.options_hop_Entry.place(x=15, y=4, width=5, height=-8, + relx=1/3 + 1/3/2, rely=1/self.COL2_ROWS, relwidth=1/3/4, relheight=1/self.COL2_ROWS) + # WINDOW SIZE + self.options_winSize_Label.place(x=5, y=4, width=5, height=-8, + relx=1/3, rely=2/self.COL2_ROWS, relwidth=1/3/2, relheight=1/self.COL2_ROWS) + self.options_winSize_Entry.place(x=15, y=4, width=5, height=-8, + relx=1/3 + 1/3/2, rely=2/self.COL2_ROWS, relwidth=1/3/4, relheight=1/self.COL2_ROWS) + # N_FFT + self.options_nfft_Label.place(x=5, y=4, width=5, height=-8, + relx=1/3, rely=3/self.COL2_ROWS, relwidth=1/3/2, relheight=1/self.COL2_ROWS) + self.options_nfft_Entry.place(x=15, y=4, width=5, height=-8, + relx=1/3 + 1/3/2, rely=3/self.COL2_ROWS, relwidth=1/3/4, relheight=1/self.COL2_ROWS) + # AI model + self.options_aiModel_Label.place(x=5, y=-5, width=-30, height=-8, + relx=1/3, rely=5/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.options_aiModel_Optionmenu.place(x=5, y=-5, width=-30, height=-8, + relx=1/3, rely=6/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + + # -Column 3- + # Choose Model + self.options_instrumentalModel_Label.place(x=0, y=0, width=0, height=-10, + relx=2/3, rely=0, relwidth=1/3, relheight=1/self.COL3_ROWS) + self.options_instrumentalModel_Optionmenu.place(x=15, y=-4, width=-30, height=-13, + relx=2/3, rely=1/self.COL3_ROWS, relwidth=1/3, relheight=1/self.COL3_ROWS) + self.options_stackedModel_Label.place(x=0, y=0, width=0, height=-10, + relx=2/3, rely=2/self.COL3_ROWS, relwidth=1/3, relheight=1/self.COL3_ROWS) + self.options_stackedModel_Optionmenu.place(x=15, y=-4, width=-30, height=-13, + relx=2/3, rely=3/self.COL3_ROWS, relwidth=1/3, relheight=1/self.COL3_ROWS) + self.options_model_Button.place(x=15, y=3, width=-30, height=-8, + relx=2/3, rely=4/self.COL3_ROWS, relwidth=1/3, relheight=1/self.COL3_ROWS) + + # -Update Binds- + self.options_stackOnly_Checkbutton.configure(command=self.update_states) # nopep8 + self.options_stack_Checkbutton.configure(command=self.update_states) # nopep8 + self.options_stack_Entry.bind('', + lambda e: self.update_states()) + # Model name decoding + self.instrumentalModel_var.trace_add('write', + lambda *args: self.decode_modelNames()) + self.stackedModel_var.trace_add('write', + lambda *args: self.decode_modelNames()) + # Model deselect + self.aiModel_var.trace_add('write', + lambda *args: self.deselect_models()) + + # Opening filedialogs + def open_file_filedialog(self): + """Make user select music files""" + if self.lastDir is not None: + if not os.path.isdir(self.lastDir): + self.lastDir = None + + paths = tk.filedialog.askopenfilenames( + parent=self, + title=f'Select Music Files', + initialfile='', + initialdir=self.lastDir, + ) + if paths: # Path selected + self.inputPaths = paths + # Change the entry text + self.filePaths_musicFile_Entry.configure(state=tk.NORMAL) + self.filePaths_musicFile_Entry.delete(0, tk.END) + self.filePaths_musicFile_Entry.insert(0, self.inputPaths) + self.filePaths_musicFile_Entry.configure(state=tk.DISABLED) + + self.lastDir = os.path.dirname(paths[0]) + + def open_export_filedialog(self): + """Make user select a folder to export the converted files in""" + path = tk.filedialog.askdirectory( + parent=self, + title=f'Select Folder',) + if path: # Path selected + self.exportPath_var.set(path) + + def 
open_newModel_filedialog(self): + """Let user paste a ".pth" model to use for the vocal seperation""" + os.startfile('models') + + def start_conversion(self): + """ + Start the conversion for all the given mp3 and wav files + """ + # -Get all variables- + export_path = self.exportPath_var.get() + instrumentalModel_path = self.instrumentalLabel_to_path[self.instrumentalModel_var.get()] # nopep8 + stackedModel_path = self.stackedLabel_to_path[self.stackedModel_var.get()] # nopep8 + # Get constants + instrumental = get_model_values(self.instrumentalModel_var.get()) + stacked = get_model_values(self.stackedModel_var.get()) + try: + if [bool(instrumental), bool(stacked)].count(True) == 2: + sr = DEFAULT_DATA['sr'] + hop_length = DEFAULT_DATA['hop_length'] + window_size = DEFAULT_DATA['window_size'] + n_fft = DEFAULT_DATA['n_fft'] + else: + sr = int(self.srValue_var.get()) + hop_length = int(self.hopValue_var.get()) + window_size = int(self.winSize_var.get()) + n_fft = int(self.nfft_var.get()) + stackPasses = int(self.stackLoops_var.get()) + except ValueError: # Non integer was put in entry box + tk.messagebox.showwarning(master=self, + title='Invalid Input', + message='Please make sure you only input integer numbers!') + return + except SyntaxError: # Non integer was put in entry box + tk.messagebox.showwarning(master=self, + title='Invalid Music File', + message='You have selected an invalid music file!\nPlease make sure that your files still exist and ends with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"') + return + + # -Check for invalid inputs- + if not any([(os.path.isfile(path) and path.endswith(('.mp3', '.mp4', '.m4a', '.flac', '.wav'))) + for path in self.inputPaths]): + tk.messagebox.showwarning(master=self, + title='Invalid Music File', + message='You have selected an invalid music file!\nPlease make sure that your files still exist and ends with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"') + return + if not os.path.isdir(export_path): + tk.messagebox.showwarning(master=self, + title='Invalid Export Directory', + message='You have selected an invalid export directory!\nPlease make sure that your directory still exists!') + return + if not self.stackOnly_var.get(): + if not os.path.isfile(instrumentalModel_path): + tk.messagebox.showwarning(master=self, + title='Invalid Instrumental Model File', + message='You have selected an invalid instrumental model file!\nPlease make sure that your model file still exists!') + return + if (self.stackOnly_var.get() or + stackPasses > 0): + if not os.path.isfile(stackedModel_path): + tk.messagebox.showwarning(master=self, + title='Invalid Stacked Model File', + message='You have selected an invalid stacked model file!\nPlease make sure that your model file still exists!') + return + + # -Save Data- + save_data(data={ + 'export_path': export_path, + 'gpu': self.gpuConversion_var.get(), + 'postprocess': self.postprocessing_var.get(), + 'tta': self.tta_var.get(), + 'output_image': self.outputImage_var.get(), + 'stack': self.stack_var.get(), + 'stackOnly': self.stackOnly_var.get(), + 'stackPasses': stackPasses, + 'saveAllStacked': self.saveAllStacked_var.get(), + 'sr': sr, + 'hop_length': hop_length, + 'window_size': window_size, + 'n_fft': n_fft, + 'useModel': 'instrumental', # Always instrumental + 'lastDir': self.lastDir, + 'modelFolder': self.modelFolder_var.get(), + 'aiModel': self.aiModel_var.get(), + }) + + if self.aiModel_var.get() == 'v2': + inference = inference_v2 + elif self.aiModel_var.get() == 'v4': + inference = inference_v4 + 
else: + raise TypeError('This error should not occur.') + + # -Run the algorithm- + threading.Thread(target=inference.main, + kwargs={ + # Paths + 'input_paths': self.inputPaths, + 'export_path': export_path, + # Processing Options + 'gpu': 0 if self.gpuConversion_var.get() else -1, + 'postprocess': self.postprocessing_var.get(), + 'tta': self.tta_var.get(), # not needed for v2 + 'output_image': self.outputImage_var.get(), + # Models + 'instrumentalModel': instrumentalModel_path, + 'vocalModel': '', # Always not needed + 'stackModel': stackedModel_path, + 'useModel': 'instrumental', # Always instrumental + # Stack Options + 'stackPasses': stackPasses, + 'stackOnly': self.stackOnly_var.get(), + 'saveAllStacked': self.saveAllStacked_var.get(), + # Model Folder + 'modelFolder': self.modelFolder_var.get(), + # Constants + 'sr': sr, + 'hop_length': hop_length, + 'window_size': window_size, + 'n_fft': n_fft, # not needed for v2 + # Other Variables (Tkinter) + 'window': self, + 'text_widget': self.command_Text, + 'button_widget': self.conversion_Button, + 'progress_var': self.progress_var, + }, + daemon=True + ).start() + + # Models + def decode_modelNames(self): + """ + Enable/Disable the 4 constants based on the selected model names + """ + # Check state of model selectors + instrumental_selectable = bool(str(self.options_instrumentalModel_Optionmenu.cget('state')) == 'normal') + stacked_selectable = bool(str(self.options_stackedModel_Optionmenu.cget('state')) == 'normal') + + # Extract data from models name + instrumental = get_model_values(self.instrumentalModel_var.get()) + stacked = get_model_values(self.stackedModel_var.get()) + + # Assign widgets to constants + widgetsVars = { + 'sr': [self.options_sr_Entry, self.srValue_var], + 'hop_length': [self.options_hop_Entry, self.hopValue_var], + 'window_size': [self.options_winSize_Entry, self.winSize_var], + 'n_fft': [self.options_nfft_Entry, self.nfft_var], + } + + # Loop through each constant (key) and its widgets + for key, (widget, var) in widgetsVars.items(): + if stacked_selectable: + # Stacked model can be selected + if key in stacked.keys(): + if (key in stacked.keys() and + not instrumental_selectable): + # Only stacked selectable + widget.configure(state=tk.DISABLED) + var.set(stacked[key]) + continue + elif (key in instrumental.keys() and + instrumental_selectable): + # Both models have set constants + widget.configure(state=tk.DISABLED) + var.set('%d/%d' % (instrumental[key], stacked[key])) + continue + else: + # Stacked model can not be selected + if (key in instrumental.keys() and + instrumental_selectable): + widget.configure(state=tk.DISABLED) + var.set(instrumental[key]) + continue + + # If widget is already enabled, no need to reset the value + if str(widget.cget('state')) != 'normal': + widget.configure(state=tk.NORMAL) + var.set(DEFAULT_DATA[key]) + + def update_loop(self): + """Update the dropdown menu""" + self.update_available_models() + + self.after(3000, self.update_loop) + + def update_available_models(self): + """ + Loop through every model (.pth) in the models directory + and add to the select your model list + """ + temp_instrumentalModels_dir = os.path.join(instrumentalModels_dir, self.aiModel_var.get(), 'Instrumental Models') # nopep8 + temp_stackedModels_dir = os.path.join(stackedModels_dir, self.aiModel_var.get(), 'Stacked Models') + # Instrumental models + new_InstrumentalModels = os.listdir(temp_instrumentalModels_dir) + if new_InstrumentalModels != self.lastInstrumentalModels: + 
self.instrumentalLabel_to_path.clear() + self.options_instrumentalModel_Optionmenu['menu'].delete(0, 'end') + for file_name in new_InstrumentalModels: + if file_name.endswith('.pth'): + # Add Radiobutton to the Options Menu + self.options_instrumentalModel_Optionmenu['menu'].add_radiobutton(label=file_name, + command=tk._setit(self.instrumentalModel_var, file_name)) + # Link the files name to its absolute path + self.instrumentalLabel_to_path[file_name] = os.path.join(temp_instrumentalModels_dir, file_name) # nopep8 + self.lastInstrumentalModels = new_InstrumentalModels + # Stacked models + new_stackedModels = os.listdir(temp_stackedModels_dir) + if new_stackedModels != self.lastStackedModels: + self.stackedLabel_to_path.clear() + self.options_stackedModel_Optionmenu['menu'].delete(0, 'end') + for file_name in new_stackedModels: + if file_name.endswith('.pth'): + # Add Radiobutton to the Options Menu + self.options_stackedModel_Optionmenu['menu'].add_radiobutton(label=file_name, + command=tk._setit(self.stackedModel_var, file_name)) + # Link the files name to its absolute path + self.stackedLabel_to_path[file_name] = os.path.join(temp_stackedModels_dir, file_name) # nopep8 + self.lastStackedModels = new_stackedModels + + def update_states(self): + """ + Vary the states for all widgets based + on certain selections + """ + try: + stackLoops = int(self.stackLoops_var.get()) + except ValueError: + stackLoops = 0 + + # Stack Passes + if self.stack_var.get(): + self.options_stack_Entry.configure(state=tk.NORMAL) + if stackLoops <= 0: + self.stackLoops_var.set(1) + stackLoops = 1 + else: + self.options_stack_Entry.configure(state=tk.DISABLED) + self.stackLoops_var.set(0) + stackLoops = 0 + + # Stack Only and Save All Outputs + if stackLoops > 0: + self.options_stackOnly_Checkbutton.configure(state=tk.NORMAL) + self.options_saveStack_Checkbutton.configure(state=tk.NORMAL) + else: + self.options_stackOnly_Checkbutton.configure(state=tk.DISABLED) + self.options_saveStack_Checkbutton.configure(state=tk.DISABLED) + self.saveAllStacked_var.set(False) + self.stackOnly_var.set(False) + + # Models + if self.stackOnly_var.get(): + # Instrumental Model + self.options_instrumentalModel_Label.configure(foreground='#777') + self.options_instrumentalModel_Optionmenu.configure(state=tk.DISABLED) # nopep8 + self.instrumentalModel_var.set('') + # Stack Model + self.options_stackedModel_Label.configure(foreground='#000') + self.options_stackedModel_Optionmenu.configure(state=tk.NORMAL) # nopep8 + else: + # Instrumental Model + self.options_instrumentalModel_Label.configure(foreground='#000') + self.options_instrumentalModel_Optionmenu.configure(state=tk.NORMAL) # nopep8 + self.instrumentalModel_var.set('') + + # Stack Model + if stackLoops > 0: + self.options_stackedModel_Label.configure(foreground='#000') + self.options_stackedModel_Optionmenu.configure(state=tk.NORMAL) # nopep8 + else: + self.options_stackedModel_Label.configure(foreground='#777') + self.options_stackedModel_Optionmenu.configure(state=tk.DISABLED) # nopep8 + self.stackedModel_var.set('') + + if self.aiModel_var.get() == 'v2': + self.options_tta_Checkbutton.configure(state=tk.DISABLED) + self.options_nfft_Label.place_forget() + self.options_nfft_Entry.place_forget() + else: + self.options_tta_Checkbutton.configure(state=tk.NORMAL) + self.options_nfft_Label.place(x=5, y=4, width=5, height=-8, + relx=1/3, rely=3/self.COL2_ROWS, relwidth=1/3/2, relheight=1/self.COL2_ROWS) + self.options_nfft_Entry.place(x=15, y=4, width=5, height=-8, + relx=1/3 + 
1/3/2, rely=3/self.COL2_ROWS, relwidth=1/3/4, relheight=1/self.COL2_ROWS) + + self.decode_modelNames() + + def deselect_models(self): + """ + Run this method on version change + """ + if self.aiModel_var.get() == self.last_aiModel: + return + else: + self.last_aiModel = self.aiModel_var.get() + + self.instrumentalModel_var.set('') + self.stackedModel_var.set('') + + self.srValue_var.set(DEFAULT_DATA['sr']) + self.hopValue_var.set(DEFAULT_DATA['hop_length']) + self.winSize_var.set(DEFAULT_DATA['window_size']) + self.nfft_var.set(DEFAULT_DATA['n_fft']) + + self.update_available_models() + self.update_states() + + def restart(self): + """ + Restart the application after asking for confirmation + """ + proceed = tk.messagebox.askyesno(title='Confirmation', + message='The application will restart and lose unsaved data. Do you wish to proceed?') + if proceed: + subprocess.Popen(f'python "{__file__}"', shell=True) + exit() + + +if __name__ == "__main__": + root = MainWindow() + + root.mainloop() diff --git a/inference_v2.py b/inference_v2.py new file mode 100644 index 0000000..b530adf --- /dev/null +++ b/inference_v2.py @@ -0,0 +1,489 @@ +import argparse +import os + +import cv2 +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm + +from lib_v2 import dataset +from lib_v2 import nets +from lib_v2 import spec_utils + +import torch +# Variable manipulation and command line text parsing +from collections import defaultdict +import tkinter as tk +import time # Timer +import traceback # Error Message Recent Calls + + +class Namespace: + """ + Replaces ArgumentParser + """ + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +data = { + # Paths + 'input_paths': None, + 'export_path': None, + # Processing Options + 'gpu': -1, + 'postprocess': True, + 'tta': True, + 'output_image': True, + # Models + 'instrumentalModel': None, + 'vocalModel': None, + 'stackModel': None, + 'useModel': None, + # Stack Options + 'stackPasses': 0, + 'stackOnly': False, + 'saveAllStacked': False, + # Model Folder + 'modelFolder': False, + # Constants + 'sr': 44_100, + 'hop_length': 1_024, + 'window_size': 512, + 'n_fft': 2_048, +} +default_sr = data['sr'] +default_hop_length = data['hop_length'] +default_window_size = data['window_size'] +default_n_fft = data['n_fft'] + + +def update_progress(progress_var, total_files, total_loops, file_num, loop_num, step: float = 1): + """Calculate the progress for the progress widget in the GUI""" + base = (100 / total_files) + progress = base * (file_num - 1) + progress += (base / total_loops) * (loop_num + step) + + progress_var.set(progress) + + +def get_baseText(total_files, total_loops, file_num, loop_num): + """Create the base text for the command widget""" + text = 'File {file_num}/{total_files}:{loop} '.format(file_num=file_num, + total_files=total_files, + loop='' if total_loops <= 1 else f' ({loop_num+1}/{total_loops})') + return text + + +def update_constants(model_name): + """ + Decode the conversion settings from the model's name + """ + global data + text = model_name.replace('.pth', '') + text_parts = text.split('_')[1:] + + # First set everything to default -> + # If file name is not decodeable (invalid or no text_parts), constants stay at default + data['sr'] = default_sr + data['hop_length'] = default_hop_length + data['window_size'] = default_window_size + data['n_fft'] = default_n_fft + + for text_part in text_parts: + if 'sr' in text_part: + text_part = text_part.replace('sr', '') + if text_part.isdecimal(): + try: + 
data['sr'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'hl' in text_part: + text_part = text_part.replace('hl', '') + if text_part.isdecimal(): + try: + data['hop_length'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'w' in text_part: + text_part = text_part.replace('w', '') + if text_part.isdecimal(): + try: + data['window_size'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'nf' in text_part: + text_part = text_part.replace('nf', '') + if text_part.isdecimal(): + try: + data['n_fft'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + + +def determineModelFolderName(): + """ + Determine the name that is used for the folder and appended + to the back of the music files + """ + modelFolderName = '' + if not data['modelFolder']: + # Model Test Mode not selected + return modelFolderName + + # -Instrumental- + if os.path.isfile(data['instrumentalModel']): + modelFolderName += os.path.splitext(os.path.basename(data['instrumentalModel']))[0] + '-' + # -Vocal- + elif os.path.isfile(data['vocalModel']): + modelFolderName += os.path.splitext(os.path.basename(data['vocalModel']))[0] + '-' + # -Stack- + if os.path.isfile(data['stackModel']): + modelFolderName += os.path.splitext(os.path.basename(data['stackModel']))[0] + else: + modelFolderName = modelFolderName[:-1] + + if modelFolderName: + modelFolderName = '/' + modelFolderName + + return modelFolderName + + +def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable, + **kwargs: dict): + def load_models(): + text_widget.write('Loading models...\n') # nopep8 Write Command Text + models = defaultdict(lambda: None) + devices = defaultdict(lambda: None) + + # -Instrumental- + if os.path.isfile(data['instrumentalModel']): + device = torch.device('cpu') + model = nets.CascadedASPPNet() + model.load_state_dict(torch.load(data['instrumentalModel'], + map_location=device)) + if torch.cuda.is_available() and data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(data['gpu'])) + model.to(device) + + models['instrumental'] = model + devices['instrumental'] = device + # -Vocal- + elif os.path.isfile(data['vocalModel']): + device = torch.device('cpu') + model = nets.CascadedASPPNet() + model.load_state_dict(torch.load(data['vocalModel'], + map_location=device)) + if torch.cuda.is_available() and data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(data['gpu'])) + model.to(device) + + models['vocal'] = model + devices['vocal'] = device + # -Stack- + if os.path.isfile(data['stackModel']): + device = torch.device('cpu') + model = nets.CascadedASPPNet() + model.load_state_dict(torch.load(data['stackModel'], + map_location=device)) + if torch.cuda.is_available() and data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(data['gpu'])) + model.to(device) + + models['stack'] = model + devices['stack'] = device + + text_widget.write('Done!\n') + return models, devices + + def load_wave_source(): + X, sr = librosa.load(music_file, + data['sr'], + False, + dtype=np.float32, + res_type='kaiser_fast') + + return X, sr + + def stft_wave_source(X, model, device): + X = spec_utils.calc_spec(X, data['hop_length']) + X, phase = np.abs(X), np.exp(1.j * np.angle(X)) + coeff = X.max() + X /= coeff + + offset = model.offset + l, r, roi_size = dataset.make_padding( + X.shape[2], data['window_size'], offset) + X_pad = np.pad(X, ((0, 0), (0, 0), 
(l, r)), mode='constant') + X_roll = np.roll(X_pad, roi_size // 2, axis=2) + + model.eval() + with torch.no_grad(): + masks = [] + masks_roll = [] + length = int(np.ceil(X.shape[2] / roi_size)) + for i in tqdm(range(length)): + update_progress(**progress_kwargs, + step=0.1 + 0.5*(i/(length - 1))) + start = i * roi_size + X_window = torch.from_numpy(np.asarray([ + X_pad[:, :, start:start + data['window_size']], + X_roll[:, :, start:start + data['window_size']] + ])).to(device) + pred = model.predict(X_window) + pred = pred.detach().cpu().numpy() + masks.append(pred[0]) + masks_roll.append(pred[1]) + + mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]] + mask_roll = np.concatenate(masks_roll, axis=2)[ + :, :, :X.shape[2]] + mask = (mask + np.roll(mask_roll, -roi_size // 2, axis=2)) / 2 + + if data['postprocess']: + vocal = X * (1 - mask) * coeff + mask = spec_utils.mask_uninformative(mask, vocal) + + inst = X * mask * coeff + vocal = X * (1 - mask) * coeff + + return inst, vocal, phase, mask + + def invert_instrum_vocal(inst, vocal, phase): + wav_instrument = spec_utils.spec_to_wav(inst, phase, data['hop_length']) # nopep8 + wav_vocals = spec_utils.spec_to_wav(vocal, phase, data['hop_length']) # nopep8 + + return wav_instrument, wav_vocals + + def save_files(wav_instrument, wav_vocals): + """Save output music files""" + vocal_name = None + instrumental_name = None + folder = '' + + # Get the Suffix Name + if (not loop_num or + loop_num == (total_loops - 1)): # First or Last Loop + if data['stackOnly']: + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + vocal_name = '(Vocals)' + instrumental_name = '(Instrumental)' + else: + vocal_name = '(Vocal_Final_Stacked_Output)' + instrumental_name = '(Instrumental_Final_Stacked_Output)' + elif data['useModel'] == 'instrumental': + if not loop_num: # First Loop + vocal_name = '(Vocals)' + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + instrumental_name = '(Instrumental)' + else: + instrumental_name = '(Instrumental_Final_Stacked_Output)' + elif data['useModel'] == 'vocal': + if not loop_num: # First Loop + instrumental_name = '(Instrumental)' + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + vocal_name = '(Vocals)' + else: + vocal_name = '(Vocals_Final_Stacked_Output)' + if data['useModel'] == 'vocal': + # Reverse names + vocal_name, instrumental_name = instrumental_name, vocal_name + elif data['saveAllStacked']: + folder = os.path.splitext(os.path.basename(base_name))[0] + ' Stacked Outputs' # nopep8 + folder = os.path.basename(folder) + '/' + folder_path = os.path.dirname(base_name) + folder_path = os.path.join(folder_path, folder) + + if not os.path.isdir(folder_path): + os.mkdir(folder_path) + + if data['stackOnly']: + vocal_name = f'(Vocal_{loop_num}_Stacked_Output)' + instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)' + elif (data['useModel'] == 'vocal' or + data['useModel'] == 'instrumental'): + vocal_name = f'(Vocals_{loop_num}_Stacked_Output)' + instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)' + + if data['useModel'] == 'vocal': + # Reverse names + vocal_name, instrumental_name = instrumental_name, vocal_name + + # Save Temp File + # For instrumental the instrumental is the temp file + # and for vocal the instrumental is the temp file due + # to reversement + sf.write(f'temp.wav', + wav_instrument.T, sr) + + appendModelFolderName = modelFolderName.replace('/', '_') + # -Save files- + # 
Instrumental + if instrumental_name is not None: + instrumental_path = '{base_path}/{folder}{file_name}.wav'.format( + base_path=os.path.dirname(base_name), + folder=folder, + file_name=f'{os.path.basename(base_name)}_{instrumental_name}{appendModelFolderName}', + ) + sf.write(instrumental_path, + wav_instrument.T, sr) + # Vocal + if vocal_name is not None: + vocal_path = '{base_path}/{folder}{file_name}.wav'.format( + base_path=os.path.dirname(base_name), + folder=folder, + file_name=f'{os.path.basename(base_name)}_{vocal_name}{appendModelFolderName}', + ) + sf.write(vocal_path, + wav_vocals.T, sr) + + def output_image(): + norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0) + norm_mask = np.concatenate([ + np.max(norm_mask, axis=2, keepdims=True), + norm_mask], axis=2)[::-1] + _, bin_mask = cv2.imencode('.png', norm_mask) + text_widget.write(base_text + 'Saving Mask...\n') # nopep8 Write Command Text + with open(f'{base_name}_(Mask).png', mode='wb') as f: + bin_mask.tofile(f) + + data.update(kwargs) + + # Update default settings + global default_sr + global default_hop_length + global default_window_size + global default_n_fft + default_sr = data['sr'] + default_hop_length = data['hop_length'] + default_window_size = data['window_size'] + default_n_fft = data['n_fft'] + + stime = time.perf_counter() + progress_var.set(0) + text_widget.clear() + button_widget.configure(state=tk.DISABLED) # Disable Button + + models, devices = load_models() + modelFolderName = determineModelFolderName() + if modelFolderName: + folder_path = f'{data["export_path"]}{modelFolderName}' + if not os.path.isdir(folder_path): + os.mkdir(folder_path) + + # Determine Loops + total_loops = data['stackPasses'] + if not data['stackOnly']: + total_loops += 1 + + for file_num, music_file in enumerate(data['input_paths'], start=1): + try: + # Determine File Name + base_name = f'{data["export_path"]}{modelFolderName}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}' + + for loop_num in range(total_loops): + # -Determine which model will be used- + if not loop_num: + # First Iteration + if data['stackOnly']: + if os.path.isfile(data['stackModel']): + model_name = os.path.basename(data['stackModel']) + model = models['stack'] + device = devices['stack'] + else: + raise ValueError(f'Selected stack only model, however, stack model path file cannot be found\nPath: "{data["stackModel"]}"') # nopep8 + else: + model_name = os.path.basename(data[f'{data["useModel"]}Model']) + model = models[data['useModel']] + device = devices[data['useModel']] + else: + model_name = os.path.basename(data['stackModel']) + # Every other iteration + model = models['stack'] + device = devices['stack'] + # Reference new music file + music_file = 'temp.wav' + + # -Get text and update progress- + base_text = get_baseText(total_files=len(data['input_paths']), + total_loops=total_loops, + file_num=file_num, + loop_num=loop_num) + progress_kwargs = {'progress_var': progress_var, + 'total_files': len(data['input_paths']), + 'total_loops': total_loops, + 'file_num': file_num, + 'loop_num': loop_num} + update_progress(**progress_kwargs, + step=0) + update_constants(model_name) + + # -Go through the different steps of seperation- + # Wave source + text_widget.write(base_text + 'Loading wave source...\n') # nopep8 Write Command Text + X, sr = load_wave_source() + text_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text + + update_progress(**progress_kwargs, + step=0.1) + # Stft of wave source + text_widget.write(base_text + 'Stft 
of wave source...\n') # nopep8 Write Command Text + inst, vocal, phase, mask = stft_wave_source(X, model, device) + text_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text + + update_progress(**progress_kwargs, + step=0.6) + # Inverse stft + text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8 Write Command Text + wav_instrument, wav_vocals = invert_instrum_vocal(inst, vocal, phase) # nopep8 + text_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text + + update_progress(**progress_kwargs, + step=0.7) + # Save Files + text_widget.write(base_text + 'Saving Files...\n') # nopep8 Write Command Text + save_files(wav_instrument, wav_vocals) + text_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text + + update_progress(**progress_kwargs, + step=0.8) + + else: + # Save Output Image (Mask) + if data['output_image']: + text_widget.write(base_text + 'Creating Mask...\n') # nopep8 Write Command Text + output_image() + text_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text + + text_widget.write(base_text + 'Completed Seperation!\n\n') # nopep8 Write Command Text + except Exception as e: + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\nLoop: {loop_num}\nPlease contact the creator and attach a screenshot of this error with the file and settings that caused it!' + tk.messagebox.showerror(master=window, + title='Untracked Error', + message=message) + print(traceback_text) + print(type(e).__name__, e) + print(message) + progress_var.set(0) + button_widget.configure(state=tk.NORMAL) # Enable Button + return + + os.remove('temp.wav') + progress_var.set(0) # Update Progress + text_widget.write(f'Conversion(s) Completed and Saving all Files!\n') # nopep8 Write Command Text + text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}') # nopep8 + button_widget.configure(state=tk.NORMAL) # Enable Button diff --git a/inference_v4.py b/inference_v4.py new file mode 100644 index 0000000..5d8fd8b --- /dev/null +++ b/inference_v4.py @@ -0,0 +1,525 @@ +import pprint +import argparse +import os + +import cv2 +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm + +from lib_v4 import dataset +from lib_v4 import nets +from lib_v4 import spec_utils +import torch + +# Command line text parsing and widget manipulation +from collections import defaultdict +import tkinter as tk +import traceback # Error Message Recent Calls +import time # Timer + + +class VocalRemover(object): + + def __init__(self, data, text_widget: tk.Text): + self.data = data + self.text_widget = text_widget + self.models = defaultdict(lambda: None) + self.devices = defaultdict(lambda: None) + self._load_models() + # self.offset = model.offset + + def _load_models(self): + self.text_widget.write('Loading models...\n') # nopep8 Write Command Text + + # -Instrumental- + if os.path.isfile(data['instrumentalModel']): + device = torch.device('cpu') + model = nets.CascadedASPPNet(self.data['n_fft']) + model.load_state_dict(torch.load(self.data['instrumentalModel'], + map_location=device)) + if torch.cuda.is_available() and self.data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(self.data['gpu'])) + model.to(device) + + self.models['instrumental'] = model + self.devices['instrumental'] = device + # -Vocal- + elif os.path.isfile(data['vocalModel']): + device = torch.device('cpu') + 
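Aside (not part of the patch): every branch of _load_models repeats the same checkpoint-loading pattern, load the state dict with a CPU map_location first, then move the model to CUDA only if a GPU was requested and is available. A minimal generic sketch of that pattern; load_checkpoint is a hypothetical helper name, not something this patch defines:

import torch
import torch.nn as nn

def load_checkpoint(model: nn.Module, checkpoint_path: str, gpu: int = -1) -> torch.device:
    """Load weights onto the CPU, then optionally move the model to a CUDA device."""
    device = torch.device('cpu')
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    if gpu >= 0 and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
        model = model.to(device)
    return device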
model = nets.CascadedASPPNet(self.data['n_fft']) + model.load_state_dict(torch.load(self.data['vocalModel'], + map_location=device)) + if torch.cuda.is_available() and self.data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(self.data['gpu'])) + model.to(device) + + self.models['vocal'] = model + self.devices['vocal'] = device + # -Stack- + if os.path.isfile(self.data['stackModel']): + device = torch.device('cpu') + model = nets.CascadedASPPNet(self.data['n_fft']) + model.load_state_dict(torch.load(self.data['stackModel'], + map_location=device)) + if torch.cuda.is_available() and self.data['gpu'] >= 0: + device = torch.device('cuda:{}'.format(self.data['gpu'])) + model.to(device) + + self.models['stack'] = model + self.devices['stack'] = device + + self.text_widget.write('Done!\n') + + def _execute(self, X_mag_pad, roi_size, n_window, device, model): + model.eval() + with torch.no_grad(): + preds = [] + for i in tqdm(range(n_window)): + start = i * roi_size + X_mag_window = X_mag_pad[None, :, :, + start:start + self.data['window_size']] + X_mag_window = torch.from_numpy(X_mag_window).to(device) + + pred = model.predict(X_mag_window) + + pred = pred.detach().cpu().numpy() + preds.append(pred[0]) + + pred = np.concatenate(preds, axis=2) + + return pred + + def preprocess(self, X_spec): + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + + def inference(self, X_spec, device, model): + X_mag, X_phase = self.preprocess(X_spec) + + coef = X_mag.max() + X_mag_pre = X_mag / coef + + n_frame = X_mag_pre.shape[2] + pad_l, pad_r, roi_size = dataset.make_padding(n_frame, + self.data['window_size'], model.offset) + n_window = int(np.ceil(n_frame / roi_size)) + + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + + pred = self._execute(X_mag_pad, roi_size, n_window, + device, model) + pred = pred[:, :, :n_frame] + + return pred * coef, X_mag, np.exp(1.j * X_phase) + + def inference_tta(self, X_spec, device, model): + X_mag, X_phase = self.preprocess(X_spec) + + coef = X_mag.max() + X_mag_pre = X_mag / coef + + n_frame = X_mag_pre.shape[2] + pad_l, pad_r, roi_size = dataset.make_padding(n_frame, + self.data['window_size'], model.offset) + n_window = int(np.ceil(n_frame / roi_size)) + + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + + pred = self._execute(X_mag_pad, roi_size, n_window, + device, model) + pred = pred[:, :, :n_frame] + + pad_l += roi_size // 2 + pad_r += roi_size // 2 + n_window += 1 + + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + + pred_tta = self._execute(X_mag_pad, roi_size, n_window, + device, model) + pred_tta = pred_tta[:, :, roi_size // 2:] + pred_tta = pred_tta[:, :, :n_frame] + + return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase) + + +data = { + # Paths + 'input_paths': None, + 'export_path': None, + # Processing Options + 'gpu': -1, + 'postprocess': True, + 'tta': True, + 'output_image': True, + # Models + 'instrumentalModel': None, + 'vocalModel': None, + 'stackModel': None, + 'useModel': None, + # Stack Options + 'stackPasses': 0, + 'stackOnly': False, + 'saveAllStacked': False, + # Constants + 'sr': 44_100, + 'hop_length': 1_024, + 'window_size': 512, + 'n_fft': 2_048, +} +default_sr = data['sr'] +default_hop_length = data['hop_length'] +default_window_size = data['window_size'] +default_n_fft = data['n_fft'] + + +def update_progress(progress_var, total_files, total_loops, file_num, loop_num, step: float = 1): 
+ """Calculate the progress for the progress widget in the GUI""" + base = (100 / total_files) + progress = base * (file_num - 1) + progress += (base / total_loops) * (loop_num + step) + + progress_var.set(progress) + + +def get_baseText(total_files, total_loops, file_num, loop_num): + """Create the base text for the command widget""" + text = 'File {file_num}/{total_files}:{loop} '.format(file_num=file_num, + total_files=total_files, + loop='' if total_loops <= 1 else f' ({loop_num+1}/{total_loops})') + return text + + +def update_constants(model_name): + """ + Decode the conversion settings from the model's name + """ + global data + text = model_name.replace('.pth', '') + text_parts = text.split('_')[1:] + + data['sr'] = default_sr + data['hop_length'] = default_hop_length + data['window_size'] = default_window_size + data['n_fft'] = default_n_fft + + for text_part in text_parts: + if 'sr' in text_part: + text_part = text_part.replace('sr', '') + if text_part.isdecimal(): + try: + data['sr'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'hl' in text_part: + text_part = text_part.replace('hl', '') + if text_part.isdecimal(): + try: + data['hop_length'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'w' in text_part: + text_part = text_part.replace('w', '') + if text_part.isdecimal(): + try: + data['window_size'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + if 'nf' in text_part: + text_part = text_part.replace('nf', '') + if text_part.isdecimal(): + try: + data['n_fft'] = int(text_part) + continue + except ValueError: + # Cannot convert string to int + pass + + +def determineModelFolderName(): + """ + Determine the name that is used for the folder and appended + to the back of the music files + """ + modelFolderName = '' + if not data['modelFolder']: + # Model Test Mode not selected + return modelFolderName + + # -Instrumental- + if os.path.isfile(data['instrumentalModel']): + modelFolderName += os.path.splitext(os.path.basename(data['instrumentalModel']))[0] + '-' + # -Vocal- + elif os.path.isfile(data['vocalModel']): + modelFolderName += os.path.splitext(os.path.basename(data['vocalModel']))[0] + '-' + # -Stack- + if os.path.isfile(data['stackModel']): + modelFolderName += os.path.splitext(os.path.basename(data['stackModel']))[0] + else: + modelFolderName = modelFolderName[:-1] + + if modelFolderName: + modelFolderName = '/' + modelFolderName + + return modelFolderName + + +def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable, + **kwargs: dict): + def save_files(wav_instrument, wav_vocals): + """Save output music files""" + vocal_name = None + instrumental_name = None + folder = '' + + # Get the Suffix Name + if (not loop_num or + loop_num == (total_loops - 1)): # First or Last Loop + if data['stackOnly']: + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + vocal_name = '(Vocals)' + instrumental_name = '(Instrumental)' + else: + vocal_name = '(Vocal_Final_Stacked_Output)' + instrumental_name = '(Instrumental_Final_Stacked_Output)' + elif data['useModel'] == 'instrumental': + if not loop_num: # First Loop + vocal_name = '(Vocals)' + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + instrumental_name = '(Instrumental)' + else: + instrumental_name = '(Instrumental_Final_Stacked_Output)' + elif data['useModel'] == 'vocal': + if 
not loop_num: # First Loop + instrumental_name = '(Instrumental)' + if loop_num == (total_loops - 1): # Last Loop + if not (total_loops - 1): # Only 1 Loop + vocal_name = '(Vocals)' + else: + vocal_name = '(Vocals_Final_Stacked_Output)' + if data['useModel'] == 'vocal': + # Reverse names + vocal_name, instrumental_name = instrumental_name, vocal_name + elif data['saveAllStacked']: + folder = os.path.splitext(os.path.basename(base_name))[0] + ' Stacked Outputs' # nopep8 + folder = os.path.basename(folder) + '/' + folder_path = os.path.dirname(base_name) + folder_path = os.path.join(folder_path, folder) + + if not os.path.isdir(folder_path): + os.mkdir(folder_path) + + if data['stackOnly']: + vocal_name = f'(Vocal_{loop_num}_Stacked_Output)' + instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)' + elif (data['useModel'] == 'vocal' or + data['useModel'] == 'instrumental'): + vocal_name = f'(Vocals_{loop_num}_Stacked_Output)' + instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)' + + if data['useModel'] == 'vocal': + # Reverse names + vocal_name, instrumental_name = instrumental_name, vocal_name + + # Save Temp File + # For instrumental the instrumental is the temp file + # and for vocal the instrumental is the temp file due + # to reversement + sf.write(f'temp.wav', + wav_instrument.T, sr) + + appendModelFolderName = modelFolderName.replace('/', '_') + # -Save files- + # Instrumental + if instrumental_name is not None: + instrumental_path = '{base_path}/{folder}{file_name}.wav'.format( + base_path=os.path.dirname(base_name), + folder=folder, + file_name=f'{os.path.basename(base_name)}_{instrumental_name}{appendModelFolderName}', + ) + + sf.write(instrumental_path, + wav_instrument.T, sr) + # Vocal + if vocal_name is not None: + vocal_path = '{base_path}/{folder}{file_name}.wav'.format( + base_path=os.path.dirname(base_name), + folder=folder, + file_name=f'{os.path.basename(base_name)}_{vocal_name}{appendModelFolderName}', + ) + sf.write(vocal_path, + wav_vocals.T, sr) + + data.update(kwargs) + + # Update default settings + global default_sr + global default_hop_length + global default_window_size + global default_n_fft + default_sr = data['sr'] + default_hop_length = data['hop_length'] + default_window_size = data['window_size'] + default_n_fft = data['n_fft'] + + stime = time.perf_counter() + progress_var.set(0) + text_widget.clear() + button_widget.configure(state=tk.DISABLED) # Disable Button + + vocal_remover = VocalRemover(data, text_widget) + modelFolderName = determineModelFolderName() + if modelFolderName: + folder_path = f'{data["export_path"]}{modelFolderName}' + if not os.path.isdir(folder_path): + os.mkdir(folder_path) + + # Determine Loops + total_loops = data['stackPasses'] + if not data['stackOnly']: + total_loops += 1 + for file_num, music_file in enumerate(data['input_paths'], start=1): + try: + # Determine File Name + base_name = f'{data["export_path"]}{modelFolderName}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}' + + # --Seperate Music Files-- + for loop_num in range(total_loops): + # -Determine which model will be used- + if not loop_num: + # First Iteration + if data['stackOnly']: + if os.path.isfile(data['stackModel']): + model_name = os.path.basename(data['stackModel']) + model = vocal_remover.models['stack'] + device = vocal_remover.devices['stack'] + else: + raise ValueError(f'Selected stack only model, however, stack model path file cannot be found\nPath: "{data["stackModel"]}"') # nopep8 + else: + model_name = 
+    for file_num, music_file in enumerate(data['input_paths'], start=1):
+        try:
+            # Determine File Name
+            base_name = f'{data["export_path"]}{modelFolderName}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'
+
+            # --Separate Music Files--
+            for loop_num in range(total_loops):
+                # -Determine which model will be used-
+                if not loop_num:
+                    # First Iteration
+                    if data['stackOnly']:
+                        if os.path.isfile(data['stackModel']):
+                            model_name = os.path.basename(data['stackModel'])
+                            model = vocal_remover.models['stack']
+                            device = vocal_remover.devices['stack']
+                        else:
+                            raise ValueError(f'Selected stack only model, however, stack model path file cannot be found\nPath: "{data["stackModel"]}"')  # nopep8
+                    else:
+                        model_name = os.path.basename(data[f'{data["useModel"]}Model'])
+                        model = vocal_remover.models[data['useModel']]
+                        device = vocal_remover.devices[data['useModel']]
+                else:
+                    # Every other iteration uses the stack model
+                    model_name = os.path.basename(data['stackModel'])
+                    model = vocal_remover.models['stack']
+                    device = vocal_remover.devices['stack']
+                    # Reference the temp file written by the previous pass
+                    music_file = 'temp.wav'
+
+                # -Get text and update progress-
+                base_text = get_baseText(total_files=len(data['input_paths']),
+                                         total_loops=total_loops,
+                                         file_num=file_num,
+                                         loop_num=loop_num)
+                progress_kwargs = {'progress_var': progress_var,
+                                   'total_files': len(data['input_paths']),
+                                   'total_loops': total_loops,
+                                   'file_num': file_num,
+                                   'loop_num': loop_num}
+                update_progress(**progress_kwargs,
+                                step=0)
+                update_constants(model_name)
+
+                # -Go through the different steps of separation-
+                # Wave source
+                text_widget.write(base_text + 'Loading wave source...\n')
+                X, sr = librosa.load(music_file, data['sr'], False,
+                                     dtype=np.float32, res_type='kaiser_fast')
+                if X.ndim == 1:
+                    X = np.asarray([X, X])
+                text_widget.write(base_text + 'Done!\n')
+
+                update_progress(**progress_kwargs,
+                                step=0.1)
+                # STFT of wave source
+                text_widget.write(base_text + 'Stft of wave source...\n')
+                X = spec_utils.wave_to_spectrogram(X,
+                                                   data['hop_length'], data['n_fft'])
+                if data['tta']:
+                    pred, X_mag, X_phase = vocal_remover.inference_tta(X,
+                                                                       device=device,
+                                                                       model=model)
+                else:
+                    pred, X_mag, X_phase = vocal_remover.inference(X,
+                                                                   device=device,
+                                                                   model=model)
+                text_widget.write(base_text + 'Done!\n')
+
+                update_progress(**progress_kwargs,
+                                step=0.6)
+                # Postprocess
+                if data['postprocess']:
+                    text_widget.write(base_text + 'Post processing...\n')
+                    pred_inv = np.clip(X_mag - pred, 0, np.inf)
+                    pred = spec_utils.mask_silence(pred, pred_inv)
+                    text_widget.write(base_text + 'Done!\n')
+
+                update_progress(**progress_kwargs,
+                                step=0.65)
+
+                # Inverse stft
+                text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n')  # nopep8
+                y_spec = pred * X_phase
+                wav_instrument = spec_utils.spectrogram_to_wave(y_spec,
+                                                                hop_length=data['hop_length'])
+                v_spec = np.clip(X_mag - pred, 0, np.inf) * X_phase
+                wav_vocals = spec_utils.spectrogram_to_wave(v_spec,
+                                                            hop_length=data['hop_length'])
+                text_widget.write(base_text + 'Done!\n')
+
+                update_progress(**progress_kwargs,
+                                step=0.7)
+                # Save output music files
+                text_widget.write(base_text + 'Saving Files...\n')
+                save_files(wav_instrument, wav_vocals)
+                text_widget.write(base_text + 'Done!\n')
+
+                update_progress(**progress_kwargs,
+                                step=0.8)
+            else:
+                # for-else: only runs when every pass finished without break;
+                # the spectrograms from the final pass are used for the images
+                # Save output image
+                if data['output_image']:
+                    with open('{}_Instruments.jpg'.format(base_name), mode='wb') as f:
+                        image = spec_utils.spectrogram_to_image(y_spec)
+                        _, bin_image = cv2.imencode('.jpg', image)
+                        bin_image.tofile(f)
+                    with open('{}_Vocals.jpg'.format(base_name), mode='wb') as f:
+                        image = spec_utils.spectrogram_to_image(v_spec)
+                        _, bin_image = cv2.imencode('.jpg', image)
+                        bin_image.tofile(f)
+
+            text_widget.write(base_text + 'Completed Separation!\n\n')
+        except Exception as e:
+            traceback_text = ''.join(traceback.format_tb(e.__traceback__))
+            message = f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\nLoop: {loop_num}\nPlease contact the creator and attach a screenshot of this error with the file and settings that caused it!'
+            tk.messagebox.showerror(master=window,
+                                    title='Untracked Error',
+                                    message=message)
+            print(traceback_text)
+            print(type(e).__name__, e)
+            print(message)
+            progress_var.set(0)
+            button_widget.configure(state=tk.NORMAL)  # Enable Button
+            return
+
+    os.remove('temp.wav')
+    progress_var.set(0)
+    text_widget.write('Conversion(s) Completed and Saving all Files!\n')
+    text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')  # nopep8
+    button_widget.configure(state=tk.NORMAL)  # Enable Button
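+
+
+# Usage sketch (illustrative only; `root`, `console`, `convert_btn` and
+# `progress` are hypothetical stand-ins for the tkinter objects the GUI owns):
+#
+#     main(window=root, text_widget=console, button_widget=convert_btn,
+#          progress_var=progress, input_paths=['song.mp3'],
+#          export_path='/tmp/out', stackPasses=1, stackOnly=False)
+#
+# Every keyword argument beyond the widgets is merged into `data` via
+# data.update(kwargs) at the top of main().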