Add files via upload
This commit is contained in:
198
inference_v5.py
198
inference_v5.py
@@ -1,7 +1,16 @@
|
|||||||
|
from functools import total_ordering
|
||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import importlib
|
import importlib
|
||||||
|
from statistics import mode
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import contextlib
|
||||||
|
from subprocess import run
|
||||||
|
from tkinter.ttk import Progressbar
|
||||||
|
from typing import _SpecialForm, overload
|
||||||
|
from unittest.mock import _SpecState
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import librosa
|
import librosa
|
||||||
@@ -20,9 +29,7 @@ from collections import defaultdict
|
|||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
import traceback # Error Message Recent Calls
|
import traceback # Error Message Recent Calls
|
||||||
import time # Timer
|
import time # Timer
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VocalRemover(object):
|
class VocalRemover(object):
|
||||||
|
|
||||||
@@ -44,7 +51,10 @@ class VocalRemover(object):
|
|||||||
|
|
||||||
global args
|
global args
|
||||||
global model_params_d
|
global model_params_d
|
||||||
|
|
||||||
|
|
||||||
|
#progressb = tqdm
|
||||||
|
|
||||||
p = argparse.ArgumentParser()
|
p = argparse.ArgumentParser()
|
||||||
p.add_argument('--paramone', type=str, default='lib_v5/modelparams/4band_44100.json')
|
p.add_argument('--paramone', type=str, default='lib_v5/modelparams/4band_44100.json')
|
||||||
p.add_argument('--paramtwo', type=str, default='lib_v5/modelparams/4band_v2.json')
|
p.add_argument('--paramtwo', type=str, default='lib_v5/modelparams/4band_v2.json')
|
||||||
@@ -96,85 +106,6 @@ class VocalRemover(object):
|
|||||||
|
|
||||||
self.text_widget.write('Done!\n')
|
self.text_widget.write('Done!\n')
|
||||||
|
|
||||||
def _execute(self, X_mag_pad, roi_size, n_window, device, model, aggressiveness):
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
preds = []
|
|
||||||
for i in tqdm(range(n_window)):
|
|
||||||
start = i * roi_size
|
|
||||||
X_mag_window = X_mag_pad[None, :, :, start:start + self.data['window_size']]
|
|
||||||
X_mag_window = torch.from_numpy(X_mag_window).to(device)
|
|
||||||
|
|
||||||
pred = model.predict(X_mag_window, aggressiveness)
|
|
||||||
|
|
||||||
pred = pred.detach().cpu().numpy()
|
|
||||||
preds.append(pred[0])
|
|
||||||
|
|
||||||
pred = np.concatenate(preds, axis=2)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return pred
|
|
||||||
|
|
||||||
def preprocess(self, X_spec):
|
|
||||||
X_mag = np.abs(X_spec)
|
|
||||||
X_phase = np.angle(X_spec)
|
|
||||||
|
|
||||||
return X_mag, X_phase
|
|
||||||
|
|
||||||
def inference(self, X_spec, device, model, aggressiveness):
|
|
||||||
X_mag, X_phase = self.preprocess(X_spec)
|
|
||||||
|
|
||||||
coef = X_mag.max()
|
|
||||||
X_mag_pre = X_mag / coef
|
|
||||||
|
|
||||||
n_frame = X_mag_pre.shape[2]
|
|
||||||
pad_l, pad_r, roi_size = dataset.make_padding(n_frame,
|
|
||||||
self.data['window_size'], model.offset)
|
|
||||||
n_window = int(np.ceil(n_frame / roi_size))
|
|
||||||
|
|
||||||
X_mag_pad = np.pad(
|
|
||||||
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
|
||||||
|
|
||||||
pred = self._execute(X_mag_pad, roi_size, n_window,
|
|
||||||
device, model, aggressiveness)
|
|
||||||
pred = pred[:, :, :n_frame]
|
|
||||||
|
|
||||||
return pred * coef, X_mag, np.exp(1.j * X_phase)
|
|
||||||
|
|
||||||
def inference_tta(self, X_spec, device, model, aggressiveness):
|
|
||||||
X_mag, X_phase = self.preprocess(X_spec)
|
|
||||||
|
|
||||||
coef = X_mag.max()
|
|
||||||
X_mag_pre = X_mag / coef
|
|
||||||
|
|
||||||
n_frame = X_mag_pre.shape[2]
|
|
||||||
pad_l, pad_r, roi_size = dataset.make_padding(n_frame,
|
|
||||||
self.data['window_size'], model.offset)
|
|
||||||
n_window = int(np.ceil(n_frame / roi_size))
|
|
||||||
|
|
||||||
X_mag_pad = np.pad(
|
|
||||||
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
|
||||||
|
|
||||||
pred = self._execute(X_mag_pad, roi_size, n_window,
|
|
||||||
device, model, aggressiveness)
|
|
||||||
pred = pred[:, :, :n_frame]
|
|
||||||
|
|
||||||
pad_l += roi_size // 2
|
|
||||||
pad_r += roi_size // 2
|
|
||||||
n_window += 1
|
|
||||||
|
|
||||||
X_mag_pad = np.pad(
|
|
||||||
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
|
||||||
|
|
||||||
pred_tta = self._execute(X_mag_pad, roi_size, n_window,
|
|
||||||
device, model, aggressiveness)
|
|
||||||
pred_tta = pred_tta[:, :, roi_size // 2:]
|
|
||||||
pred_tta = pred_tta[:, :, :n_frame]
|
|
||||||
|
|
||||||
return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
|
|
||||||
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
# Paths
|
# Paths
|
||||||
'input_paths': None,
|
'input_paths': None,
|
||||||
@@ -230,9 +161,9 @@ def determineModelFolderName():
|
|||||||
|
|
||||||
return modelFolderName
|
return modelFolderName
|
||||||
|
|
||||||
|
|
||||||
def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable,
|
def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable,
|
||||||
**kwargs: dict):
|
**kwargs: dict):
|
||||||
|
|
||||||
def save_files(wav_instrument, wav_vocals):
|
def save_files(wav_instrument, wav_vocals):
|
||||||
"""Save output music files"""
|
"""Save output music files"""
|
||||||
vocal_name = '(Vocals)'
|
vocal_name = '(Vocals)'
|
||||||
@@ -350,23 +281,102 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
|||||||
step=0.1)
|
step=0.1)
|
||||||
|
|
||||||
text_widget.write(base_text + 'Stft of wave source...\n')
|
text_widget.write(base_text + 'Stft of wave source...\n')
|
||||||
|
|
||||||
|
text_widget.write(base_text + 'Done!\n')
|
||||||
|
|
||||||
|
text_widget.write(base_text + "Please Wait..\n")
|
||||||
|
|
||||||
X_spec_m = spec_utils.combine_spectrograms(X_spec_s, mp)
|
X_spec_m = spec_utils.combine_spectrograms(X_spec_s, mp)
|
||||||
|
|
||||||
del X_wave, X_spec_s
|
del X_wave, X_spec_s
|
||||||
|
|
||||||
if data['tta']:
|
def inference(X_spec, device, model, aggressiveness):
|
||||||
pred, X_mag, X_phase = vocal_remover.inference_tta(X_spec_m,
|
|
||||||
device,
|
def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness):
|
||||||
model, {'value': args.aggressiveness,'split_bin': mp.param['band'][1]['crop_stop']})
|
model.eval()
|
||||||
else:
|
|
||||||
pred, X_mag, X_phase = vocal_remover.inference(X_spec_m,
|
with torch.no_grad():
|
||||||
device,
|
preds = []
|
||||||
model, {'value': args.aggressiveness,'split_bin': mp.param['band'][1]['crop_stop']})
|
|
||||||
|
iterations = [n_window]
|
||||||
|
|
||||||
|
total_iterations = sum(iterations)
|
||||||
|
|
||||||
|
text_widget.write(base_text + "Length: "f"{total_iterations} Slices\n")
|
||||||
|
|
||||||
|
for i in tqdm(range(n_window)):
|
||||||
|
update_progress(**progress_kwargs,
|
||||||
|
step=(0.1 + (0.8/n_window * i)))
|
||||||
|
start = i * roi_size
|
||||||
|
X_mag_window = X_mag_pad[None, :, :, start:start + data['window_size']]
|
||||||
|
X_mag_window = torch.from_numpy(X_mag_window).to(device)
|
||||||
|
|
||||||
|
pred = model.predict(X_mag_window, aggressiveness)
|
||||||
|
|
||||||
|
pred = pred.detach().cpu().numpy()
|
||||||
|
preds.append(pred[0])
|
||||||
|
|
||||||
|
pred = np.concatenate(preds, axis=2)
|
||||||
|
|
||||||
|
return pred
|
||||||
|
|
||||||
|
def preprocess(X_spec):
|
||||||
|
X_mag = np.abs(X_spec)
|
||||||
|
X_phase = np.angle(X_spec)
|
||||||
|
|
||||||
|
return X_mag, X_phase
|
||||||
|
|
||||||
|
X_mag, X_phase = preprocess(X_spec)
|
||||||
|
|
||||||
|
coef = X_mag.max()
|
||||||
|
X_mag_pre = X_mag / coef
|
||||||
|
|
||||||
|
n_frame = X_mag_pre.shape[2]
|
||||||
|
pad_l, pad_r, roi_size = dataset.make_padding(n_frame,
|
||||||
|
data['window_size'], model.offset)
|
||||||
|
n_window = int(np.ceil(n_frame / roi_size))
|
||||||
|
|
||||||
|
X_mag_pad = np.pad(
|
||||||
|
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||||
|
|
||||||
|
pred = _execute(X_mag_pad, roi_size, n_window,
|
||||||
|
device, model, aggressiveness)
|
||||||
|
pred = pred[:, :, :n_frame]
|
||||||
|
|
||||||
|
if data['tta']:
|
||||||
|
pad_l += roi_size // 2
|
||||||
|
pad_r += roi_size // 2
|
||||||
|
n_window += 1
|
||||||
|
|
||||||
|
X_mag_pad = np.pad(
|
||||||
|
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||||
|
|
||||||
|
pred_tta = _execute(X_mag_pad, roi_size, n_window,
|
||||||
|
device, model, aggressiveness)
|
||||||
|
pred_tta = pred_tta[:, :, roi_size // 2:]
|
||||||
|
pred_tta = pred_tta[:, :, :n_frame]
|
||||||
|
|
||||||
|
return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
|
||||||
|
else:
|
||||||
|
return pred * coef, X_mag, np.exp(1.j * X_phase)
|
||||||
|
|
||||||
|
|
||||||
|
aggressiveness = {'value': args.aggressiveness, 'split_bin': mp.param['band'][1]['crop_stop']}
|
||||||
|
|
||||||
|
|
||||||
|
if data['tta']:
|
||||||
|
text_widget.write(base_text + "Running Inferences (TTA)...\n")
|
||||||
|
else:
|
||||||
|
text_widget.write(base_text + "Running Inference...\n")
|
||||||
|
|
||||||
|
pred, X_mag, X_phase = inference(X_spec_m,
|
||||||
|
device,
|
||||||
|
model, aggressiveness)
|
||||||
|
|
||||||
text_widget.write(base_text + 'Done!\n')
|
text_widget.write(base_text + 'Done!\n')
|
||||||
|
|
||||||
update_progress(**progress_kwargs,
|
update_progress(**progress_kwargs,
|
||||||
step=0.6)
|
step=0.9)
|
||||||
# Postprocess
|
# Postprocess
|
||||||
if data['postprocess']:
|
if data['postprocess']:
|
||||||
text_widget.write(base_text + 'Post processing...\n')
|
text_widget.write(base_text + 'Post processing...\n')
|
||||||
@@ -375,7 +385,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
|||||||
text_widget.write(base_text + 'Done!\n')
|
text_widget.write(base_text + 'Done!\n')
|
||||||
|
|
||||||
update_progress(**progress_kwargs,
|
update_progress(**progress_kwargs,
|
||||||
step=0.65)
|
step=0.95)
|
||||||
|
|
||||||
# Inverse stft
|
# Inverse stft
|
||||||
text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8
|
text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8
|
||||||
|
|||||||
Reference in New Issue
Block a user