diff --git a/inference_v5.py b/inference_v5.py index 09c6d06..80e2b23 100644 --- a/inference_v5.py +++ b/inference_v5.py @@ -1,7 +1,16 @@ +from functools import total_ordering import pprint import argparse import os import importlib +from statistics import mode +import sys +import subprocess +import contextlib +from subprocess import run +from tkinter.ttk import Progressbar +from typing import _SpecialForm, overload +from unittest.mock import _SpecState import cv2 import librosa @@ -20,9 +29,7 @@ from collections import defaultdict import tkinter as tk import traceback # Error Message Recent Calls import time # Timer - - - +import random class VocalRemover(object): @@ -44,7 +51,10 @@ class VocalRemover(object): global args global model_params_d - + + + #progressb = tqdm + p = argparse.ArgumentParser() p.add_argument('--paramone', type=str, default='lib_v5/modelparams/4band_44100.json') p.add_argument('--paramtwo', type=str, default='lib_v5/modelparams/4band_v2.json') @@ -96,85 +106,6 @@ class VocalRemover(object): self.text_widget.write('Done!\n') - def _execute(self, X_mag_pad, roi_size, n_window, device, model, aggressiveness): - model.eval() - with torch.no_grad(): - preds = [] - for i in tqdm(range(n_window)): - start = i * roi_size - X_mag_window = X_mag_pad[None, :, :, start:start + self.data['window_size']] - X_mag_window = torch.from_numpy(X_mag_window).to(device) - - pred = model.predict(X_mag_window, aggressiveness) - - pred = pred.detach().cpu().numpy() - preds.append(pred[0]) - - pred = np.concatenate(preds, axis=2) - - - - return pred - - def preprocess(self, X_spec): - X_mag = np.abs(X_spec) - X_phase = np.angle(X_spec) - - return X_mag, X_phase - - def inference(self, X_spec, device, model, aggressiveness): - X_mag, X_phase = self.preprocess(X_spec) - - coef = X_mag.max() - X_mag_pre = X_mag / coef - - n_frame = X_mag_pre.shape[2] - pad_l, pad_r, roi_size = dataset.make_padding(n_frame, - self.data['window_size'], model.offset) - n_window = int(np.ceil(n_frame / roi_size)) - - X_mag_pad = np.pad( - X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') - - pred = self._execute(X_mag_pad, roi_size, n_window, - device, model, aggressiveness) - pred = pred[:, :, :n_frame] - - return pred * coef, X_mag, np.exp(1.j * X_phase) - - def inference_tta(self, X_spec, device, model, aggressiveness): - X_mag, X_phase = self.preprocess(X_spec) - - coef = X_mag.max() - X_mag_pre = X_mag / coef - - n_frame = X_mag_pre.shape[2] - pad_l, pad_r, roi_size = dataset.make_padding(n_frame, - self.data['window_size'], model.offset) - n_window = int(np.ceil(n_frame / roi_size)) - - X_mag_pad = np.pad( - X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') - - pred = self._execute(X_mag_pad, roi_size, n_window, - device, model, aggressiveness) - pred = pred[:, :, :n_frame] - - pad_l += roi_size // 2 - pad_r += roi_size // 2 - n_window += 1 - - X_mag_pad = np.pad( - X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') - - pred_tta = self._execute(X_mag_pad, roi_size, n_window, - device, model, aggressiveness) - pred_tta = pred_tta[:, :, roi_size // 2:] - pred_tta = pred_tta[:, :, :n_frame] - - return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase) - - data = { # Paths 'input_paths': None, @@ -230,9 +161,9 @@ def determineModelFolderName(): return modelFolderName - def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable, **kwargs: dict): + def save_files(wav_instrument, wav_vocals): """Save output music files""" vocal_name = '(Vocals)' @@ -350,23 +281,102 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress step=0.1) text_widget.write(base_text + 'Stft of wave source...\n') + + text_widget.write(base_text + 'Done!\n') + + text_widget.write(base_text + "Please Wait..\n") + X_spec_m = spec_utils.combine_spectrograms(X_spec_s, mp) del X_wave, X_spec_s - if data['tta']: - pred, X_mag, X_phase = vocal_remover.inference_tta(X_spec_m, - device, - model, {'value': args.aggressiveness,'split_bin': mp.param['band'][1]['crop_stop']}) - else: - pred, X_mag, X_phase = vocal_remover.inference(X_spec_m, - device, - model, {'value': args.aggressiveness,'split_bin': mp.param['band'][1]['crop_stop']}) + def inference(X_spec, device, model, aggressiveness): + + def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness): + model.eval() + + with torch.no_grad(): + preds = [] + + iterations = [n_window] + total_iterations = sum(iterations) + + text_widget.write(base_text + "Length: "f"{total_iterations} Slices\n") + + for i in tqdm(range(n_window)): + update_progress(**progress_kwargs, + step=(0.1 + (0.8/n_window * i))) + start = i * roi_size + X_mag_window = X_mag_pad[None, :, :, start:start + data['window_size']] + X_mag_window = torch.from_numpy(X_mag_window).to(device) + + pred = model.predict(X_mag_window, aggressiveness) + + pred = pred.detach().cpu().numpy() + preds.append(pred[0]) + + pred = np.concatenate(preds, axis=2) + + return pred + + def preprocess(X_spec): + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + + X_mag, X_phase = preprocess(X_spec) + + coef = X_mag.max() + X_mag_pre = X_mag / coef + + n_frame = X_mag_pre.shape[2] + pad_l, pad_r, roi_size = dataset.make_padding(n_frame, + data['window_size'], model.offset) + n_window = int(np.ceil(n_frame / roi_size)) + + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + + pred = _execute(X_mag_pad, roi_size, n_window, + device, model, aggressiveness) + pred = pred[:, :, :n_frame] + + if data['tta']: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + n_window += 1 + + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + + pred_tta = _execute(X_mag_pad, roi_size, n_window, + device, model, aggressiveness) + pred_tta = pred_tta[:, :, roi_size // 2:] + pred_tta = pred_tta[:, :, :n_frame] + + return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase) + else: + return pred * coef, X_mag, np.exp(1.j * X_phase) + + + aggressiveness = {'value': args.aggressiveness, 'split_bin': mp.param['band'][1]['crop_stop']} + + + if data['tta']: + text_widget.write(base_text + "Running Inferences (TTA)...\n") + else: + text_widget.write(base_text + "Running Inference...\n") + + pred, X_mag, X_phase = inference(X_spec_m, + device, + model, aggressiveness) + text_widget.write(base_text + 'Done!\n') update_progress(**progress_kwargs, - step=0.6) + step=0.9) # Postprocess if data['postprocess']: text_widget.write(base_text + 'Post processing...\n') @@ -375,7 +385,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress text_widget.write(base_text + 'Done!\n') update_progress(**progress_kwargs, - step=0.65) + step=0.95) # Inverse stft text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8