From fe3f6d883baea0ef086b77f504d81f4ff690375e Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Mon, 9 Nov 2020 04:26:48 -0600 Subject: [PATCH] Delete inference.py --- inference.py | 200 --------------------------------------------------- 1 file changed, 200 deletions(-) delete mode 100644 inference.py diff --git a/inference.py b/inference.py deleted file mode 100644 index 2b05552..0000000 --- a/inference.py +++ /dev/null @@ -1,200 +0,0 @@ -import argparse -import os - -import cv2 -import librosa -import numpy as np -import soundfile as sf -from tqdm import tqdm - -from lib import dataset -from lib import nets -from lib import spec_utils - -# Variable manipulation and command line text parsing -import torch -import tkinter as tk -import traceback # Error Message Recent Calls - - -class Namespace: - """ - Replaces ArgumentParser - """ - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - -def main(window: tk.Wm, input_paths: list, gpu: bool = -1, - model: str = 'models/baseline.pth', sr: int = 44100, hop_length: int = 1024, - window_size: int = 512, out_mask: bool = False, postprocess: bool = False, - export_path: str = '', loops: int = 1, - # Other Variables (Tkinter) - progress_var: tk.Variable = None, button_widget: tk.Button = None, command_widget: tk.Text = None, - ): - def load_model(): - args.command_widget.write('Loading model...\n') # nopep8 Write Command Text - device = torch.device('cpu') - model = nets.CascadedASPPNet() - model.load_state_dict(torch.load(args.model, map_location=device)) - if torch.cuda.is_available() and args.gpu >= 0: - device = torch.device('cuda:{}'.format(args.gpu)) - model.to(device) - args.command_widget.write('Done!\n') # nopep8 Write Command Text - - return model, device - - def load_wave_source(): - args.command_widget.write(base_text + 'Loading wave source...\n') # nopep8 Write Command Text - X, sr = librosa.load(music_file, - args.sr, - False, - dtype=np.float32, - res_type='kaiser_fast') - args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text - - return X, sr - - def stft_wave_source(X): - args.command_widget.write(base_text + 'Stft of wave source...\n') # nopep8 Write Command Text - X = spec_utils.calc_spec(X, args.hop_length) - X, phase = np.abs(X), np.exp(1.j * np.angle(X)) - coeff = X.max() - X /= coeff - - offset = model.offset - l, r, roi_size = dataset.make_padding( - X.shape[2], args.window_size, offset) - X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant') - X_roll = np.roll(X_pad, roi_size // 2, axis=2) - - model.eval() - with torch.no_grad(): - masks = [] - masks_roll = [] - length = int(np.ceil(X.shape[2] / roi_size)) - for i in tqdm(range(length)): - progress_var.set(base_progress + max_progress * (0.1 + (0.6/length * i))) # nopep8 Update Progress - start = i * roi_size - X_window = torch.from_numpy(np.asarray([ - X_pad[:, :, start:start + args.window_size], - X_roll[:, :, start:start + args.window_size] - ])).to(device) - pred = model.predict(X_window) - pred = pred.detach().cpu().numpy() - masks.append(pred[0]) - masks_roll.append(pred[1]) - - mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]] - mask_roll = np.concatenate(masks_roll, axis=2)[ - :, :, :X.shape[2]] - mask = (mask + np.roll(mask_roll, -roi_size // 2, axis=2)) / 2 - - if args.postprocess: - vocal = X * (1 - mask) * coeff - mask = spec_utils.mask_uninformative(mask, vocal) - args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text - - inst = X * mask * coeff - vocal = X * (1 - mask) * coeff - - return inst, vocal, phase, mask - - def invert_instrum_vocal(inst, vocal, phase): - args.command_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8 Write Command Text - - wav_instrument = spec_utils.spec_to_wav(inst, phase, args.hop_length) # nopep8 - wav_vocals = spec_utils.spec_to_wav(vocal, phase, args.hop_length) # nopep8 - - args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text - - return wav_instrument, wav_vocals - - def save_files(wav_instrument, wav_vocals): - args.command_widget.write(base_text + 'Saving Files...\n') # nopep8 Write Command Text - sf.write(f'{export_path}/{base_name}_(Instrumental).wav', - wav_instrument.T, sr) - if cur_loop == 0: - sf.write(f'{export_path}/{base_name}_(Vocals).wav', - wav_vocals.T, sr) - if (cur_loop == (args.loops - 1) and - args.loops > 1): - sf.write(f'{export_path}/{base_name}_(Last_Vocals).wav', - wav_vocals.T, sr) - - args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text - - def create_mask(): - args.command_widget.write(base_text + 'Creating Mask...\n') # nopep8 Write Command Text - norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0) - norm_mask = np.concatenate([ - np.max(norm_mask, axis=2, keepdims=True), - norm_mask], axis=2)[::-1] - _, bin_mask = cv2.imencode('.png', norm_mask) - args.command_widget.write(base_text + 'Saving Mask...\n') # nopep8 Write Command Text - with open(f'{export_path}/{base_name}_(Mask).png', mode='wb') as f: - bin_mask.tofile(f) - args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text - - args = Namespace(input=input_paths, gpu=gpu, model=model, - sr=sr, hop_length=hop_length, window_size=window_size, - out_mask=out_mask, postprocess=postprocess, export=export_path, - loops=loops, - # Other Variables (Tkinter) - window=window, progress_var=progress_var, - button_widget=button_widget, command_widget=command_widget, - ) - args.command_widget.clear() # Clear Command Text - args.button_widget.configure(state=tk.DISABLED) # Disable Button - total_files = len(args.input) # Used to calculate progress - - model, device = load_model() - - for file_num, music_file in enumerate(args.input, start=1): - try: - base_name = f'{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}' - for cur_loop in range(args.loops): - if cur_loop > 0: - args.command_widget.write(f'File {file_num}/{total_files}: ' + 'Next Pass!\n') # nopep8 Write Command Text - music_file = f'{export_path}/{base_name}_(Instrumental).wav' - base_progress = 100 / \ - (total_files*args.loops) * \ - ((file_num*args.loops)-((args.loops-1) - cur_loop)-1) - base_text = 'File {file_num}/{total_files}:{loop} '.format( - file_num=file_num, - total_files=total_files, - loop='' if args.loops <= 1 else f' ({cur_loop+1}/{args.loops})') - max_progress = 100 / (total_files*args.loops) - progress_var.set(base_progress + max_progress * 0.05) # nopep8 Update Progress - - X, sr = load_wave_source() - progress_var.set(base_progress + max_progress * 0.1) # nopep8 Update Progress - - inst, vocal, phase, mask = stft_wave_source(X) - progress_var.set(base_progress + max_progress * 0.7) # nopep8 Update Progress - - wav_instrument, wav_vocals = invert_instrum_vocal(inst, vocal, phase) # nopep8 - progress_var.set(base_progress + max_progress * 0.8) # nopep8 Update Progress - - save_files(wav_instrument, wav_vocals) - progress_var.set(base_progress + max_progress * 0.9) # nopep8 Update Progress - - if args.out_mask: - create_mask() - progress_var.set(base_progress + max_progress * 1) # nopep8 Update Progress - - args.command_widget.write(base_text + 'Completed Seperation!\n\n') # nopep8 Write Command Text - except Exception as e: - traceback_text = ''.join(traceback.format_tb(e.__traceback__)) - print(traceback_text) - print(type(e).__name__, e) - tk.messagebox.showerror(master=args.window, - title='Untracked Error', - message=f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\n\nPlease contact the creator and attach a screenshot of this error with the file which caused it!') - args.button_widget.configure(state=tk.NORMAL) # Enable Button - return - - progress_var.set(100) # Update Progress - args.command_widget.write(f'Conversion(s) Completed and Saving all Files!') # nopep8 Write Command Text - args.button_widget.configure(state=tk.NORMAL) # Enable Button