Add files via upload

This commit is contained in:
Anjok07
2022-07-23 02:56:57 -05:00
committed by GitHub
parent 063f015d5d
commit b011d37a58
6 changed files with 6150 additions and 2150 deletions

View File

@@ -1,37 +1,37 @@
import os
from pathlib import Path
import os.path
from datetime import datetime
import pydub
import shutil
from random import randrange
#MDX-Net
#----------------------------------------
import soundfile as sf
import torch
import numpy as np
from demucs.pretrained import get_model as _gm
from demucs.hdemucs import HDemucs
from demucs.apply import BagOfModels, apply_model
from demucs.audio import AudioFile
import time
import os
from tqdm import tqdm
import warnings
import sys
import librosa
import psutil
#----------------------------------------
from demucs.hdemucs import HDemucs
from demucs.model_v2 import Demucs
from demucs.pretrained import get_model as _gm
from demucs.tasnet_v2 import ConvTasNet
from demucs.utils import apply_model_v1
from demucs.utils import apply_model_v2
from diffq import DiffQuantizer
from lib_v5 import spec_utils
from lib_v5.model_param_init import ModelParameters
import torch
# Command line text parsing and widget manipulation
import tkinter as tk
import traceback # Error Message Recent Calls
from pathlib import Path
from random import randrange
from tqdm import tqdm
import gzip
import io
import librosa
import numpy as np
import os
import os
import os.path
import psutil
import pydub
import shutil
import soundfile as sf
import sys
import time
import time # Timer
import tkinter as tk
import torch
import torch.hub
import traceback # Error Message Recent Calls
import warnings
import zlib
class Predictor():
def __init__(self):
@@ -46,40 +46,62 @@ class Predictor():
if data['gpu'] == -1:
device = torch.device('cpu')
self.demucs = HDemucs(sources=["drums", "bass", "other", "vocals"])
widget_text.write(base_text + 'Loading Demucs model... ')
update_progress(**progress_kwargs,
step=0.05)
path_d = Path('models/Demucs_Models')
print('What Demucs model was chosen? ', data['DemucsModel'])
self.demucs = _gm(name=data['DemucsModel'], repo=path_d)
widget_text.write('Done!\n')
if 'UVR' in data['DemucsModel']:
widget_text.write(base_text + "2 stem model selected.\n")
if isinstance(self.demucs, BagOfModels):
widget_text.write(base_text + f"Selected model is a bag of {len(self.demucs.models)} models.\n")
if data['segment'] == 'None':
segment = None
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
if demucs_model_version == 'v1':
load_from = "models/Demucs_Models/"f"{demucs_model_set_name}"
if str(load_from).endswith(".gz"):
load_from = gzip.open(load_from, "rb")
klass, args, kwargs, state = torch.load(load_from)
self.demucs = klass(*args, **kwargs)
widget_text.write(base_text + 'Loading Demucs v1 model... ')
update_progress(**progress_kwargs,
step=0.05)
self.demucs.to(device)
self.demucs.load_state_dict(state)
widget_text.write('Done!\n')
if not data['segment'] == 'None':
widget_text.write(base_text + 'Segments is only available in Demucs v3. Please use \"Chunks\" instead.\n')
else:
if segment is not None:
sub.segment = segment
else:
try:
segment = int(data['segment'])
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
widget_text.write(base_text + "Segments set to "f"{segment}.\n")
except:
pass
if demucs_model_version == 'v2':
if '48' in demucs_model_set_name:
channels=48
elif 'unittest' in demucs_model_set_name:
channels=4
else:
channels=64
if 'tasnet' in demucs_model_set_name:
self.demucs = ConvTasNet(sources=["drums", "bass", "other", "vocals"], X=10)
else:
self.demucs = Demucs(sources=["drums", "bass", "other", "vocals"], channels=channels)
widget_text.write(base_text + 'Loading Demucs v2 model... ')
update_progress(**progress_kwargs,
step=0.05)
self.demucs.to(device)
self.demucs.load_state_dict(torch.load("models/Demucs_Models/"f"{demucs_model_set_name}"))
widget_text.write('Done!\n')
if not data['segment'] == 'None':
widget_text.write(base_text + 'Segments is only available in Demucs v3. Please use \"Chunks\" instead.\n')
else:
pass
self.demucs.eval()
if demucs_model_version == 'v3':
self.demucs = HDemucs(sources=["drums", "bass", "other", "vocals"])
widget_text.write(base_text + 'Loading Demucs model... ')
update_progress(**progress_kwargs,
step=0.05)
path_d = Path('models/Demucs_Models/v3_repo')
#print('What Demucs model was chosen? ', demucs_model_set_name)
self.demucs = _gm(name=demucs_model_set_name, repo=path_d)
widget_text.write('Done!\n')
if 'UVR' in data['DemucsModel']:
widget_text.write(base_text + "2 stem model selected.\n")
if isinstance(self.demucs, BagOfModels):
widget_text.write(base_text + f"Selected model is a bag of {len(self.demucs.models)} models.\n")
if data['segment'] == 'None':
segment = None
if isinstance(self.demucs, BagOfModels):
if segment is not None:
@@ -88,9 +110,29 @@ class Predictor():
else:
if segment is not None:
sub.segment = segment
self.demucs.to(device)
self.demucs.eval()
else:
try:
segment = int(data['segment'])
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
widget_text.write(base_text + "Segments set to "f"{segment}.\n")
except:
segment = None
if isinstance(self.demucs, BagOfModels):
if segment is not None:
for sub in self.demucs.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
self.demucs.to(device)
self.demucs.eval()
update_progress(**progress_kwargs,
step=0.1)
@@ -646,7 +688,12 @@ class Predictor():
if end == samples:
break
sources = self.demix_demucs(segmented_mix, margin_size=margin)
if demucs_model_version == 'v1':
sources = self.demix_demucs_v1(segmented_mix, margin_size=margin)
if demucs_model_version == 'v2':
sources = self.demix_demucs_v2(segmented_mix, margin_size=margin)
if demucs_model_version == 'v3':
sources = self.demix_demucs(segmented_mix, margin_size=margin)
return sources
@@ -683,31 +730,94 @@ class Predictor():
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')
return sources
def demix_demucs_v1(self, mix, margin_size):
processed = {}
demucsitera = len(mix)
demucsitera_calc = demucsitera * 2
gui_progress_bar_demucs = 0
widget_text.write(base_text + "Running Demucs v1 Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")
print(' Running Demucs Inference...')
for nmix in mix:
gui_progress_bar_demucs += 1
update_progress(**progress_kwargs,
step=(0.35 + (1.05/demucsitera_calc * gui_progress_bar_demucs)))
cmix = mix[nmix]
cmix = torch.tensor(cmix, dtype=torch.float32)
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
with torch.no_grad():
sources = apply_model_v1(self.demucs, cmix.to(device), split=split_mode, shifts=shift_set)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
start = 0 if nmix == 0 else margin_size
end = None if nmix == list(mix.keys())[::-1][0] else -margin_size
if margin_size == 0:
end = None
processed[nmix] = sources[:,:,start:end].copy()
sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')
return sources
def demix_demucs_v2(self, mix, margin_size):
processed = {}
demucsitera = len(mix)
demucsitera_calc = demucsitera * 2
gui_progress_bar_demucs = 0
widget_text.write(base_text + "Running Demucs v2 Inference...\n")
widget_text.write(base_text + "Processing "f"{len(mix)} slices... ")
print(' Running Demucs Inference...')
for nmix in mix:
gui_progress_bar_demucs += 1
update_progress(**progress_kwargs,
step=(0.35 + (1.05/demucsitera_calc * gui_progress_bar_demucs)))
cmix = mix[nmix]
cmix = torch.tensor(cmix, dtype=torch.float32)
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
shift_set = 0
with torch.no_grad():
sources = apply_model_v2(self.demucs, cmix.to(device), split=split_mode, overlap=overlap_set, shifts=shift_set)
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
start = 0 if nmix == 0 else margin_size
end = None if nmix == list(mix.keys())[::-1][0] else -margin_size
if margin_size == 0:
end = None
processed[nmix] = sources[:,:,start:end].copy()
sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
widget_text.write('Done!\n')
return sources
data = {
# Paths
'input_paths': None,
'export_path': None,
'saveFormat': 'Wav',
# Processing Options
'demucsmodel': True,
'gpu': -1,
'audfile': True,
'chunks_d': 'Full',
'settest': False,
'voc_only_b': False,
'inst_only_b': False,
'overlap_b': 0.25,
'shifts_b': 2,
'segment': 'None',
'margin': 44100,
'split_mode': False,
'normalize': False,
'compensate': 1.03597672895,
'demucs_stems': 'All Stems',
'DemucsModel': 'mdx_extra',
'audfile': True,
'wavtype': 'PCM_16',
'demucsmodel': True,
'export_path': None,
'gpu': -1,
'input_paths': None,
'inst_only_b': False,
'margin': 44100,
'mp3bit': '320k',
'normalize': False,
'overlap_b': 0.25,
'saveFormat': 'Wav',
'segment': 'None',
'settest': False,
'shifts_b': 2,
'split_mode': False,
'voc_only_b': False,
'wavtype': 'PCM_16',
}
default_chunks = data['chunks_d']
@@ -756,6 +866,8 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
global shift_set
global source_val
global split_mode
global demucs_model_set_name
global demucs_model_version
global wav_type_set
global flac_type_set
@@ -817,6 +929,69 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
progress_var.set(0)
text_widget.clear()
button_widget.configure(state=tk.DISABLED) # Disable Button
if data['DemucsModel'] == "Tasnet v1":
demucs_model_set_name = 'tasnet.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Tasnet_extra v1":
demucs_model_set_name = 'tasnet_extra.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Demucs v1":
demucs_model_set_name = 'demucs.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Demucs v1.gz":
demucs_model_set_name = 'demucs.th.gz'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Demucs_extra v1":
demucs_model_set_name = 'demucs_extra.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Demucs_extra v1.gz":
demucs_model_set_name = 'demucs_extra.th.gz'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Light v1":
demucs_model_set_name = 'light.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Light v1.gz":
demucs_model_set_name = 'light.th.gz'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Light_extra v1":
demucs_model_set_name = 'light_extra.th'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Light_extra v1.gz":
demucs_model_set_name = 'light_extra.th.gz'
demucs_model_version = 'v1'
elif data['DemucsModel'] == "Tasnet v2":
demucs_model_set_name = 'tasnet-beb46fac.th'
demucs_model_version = 'v2'
elif data['DemucsModel'] == "Tasnet_extra v2":
demucs_model_set_name = 'tasnet_extra-df3777b2.th'
demucs_model_version = 'v2'
elif data['DemucsModel'] == "Demucs48_hq v2":
demucs_model_set_name = 'demucs48_hq-28a1282c.th'
demucs_model_version = 'v2'
elif data['DemucsModel'] == "Demucs v2":
demucs_model_set_name = 'demucs-e07c671f.th'
demucs_model_version = 'v2'
elif data['DemucsModel'] == "Demucs_extra v2":
demucs_model_set_name = 'demucs_extra-3646af93.th'
demucs_model_version = 'v2'
elif data['DemucsModel'] == "Demucs_unittest v2":
demucs_model_set_name = 'demucs_unittest-09ebc15f.th'
demucs_model_version = 'v2'
elif '.ckpt' in data['DemucsModel'] and 'v2' in data['DemucsModel']:
demucs_model_set_name = data['DemucsModel']
demucs_model_version = 'v2'
elif '.ckpt' in data['DemucsModel'] and 'v1' in data['DemucsModel']:
demucs_model_set_name = data['DemucsModel']
demucs_model_version = 'v1'
elif '.gz' in data['DemucsModel']:
demucs_model_set_name = data['DemucsModel']
demucs_model_version = 'v1'
else:
demucs_model_set_name = data['DemucsModel']
demucs_model_version = 'v3'
try: #Load File(s)
for file_num, music_file in tqdm(enumerate(data['input_paths'], start=1)):
@@ -880,7 +1055,10 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
channel_set = int(data['channel'])
margin_set = int(data['margin'])
shift_set = int(data['shifts_b'])
split_mode = data['split_mode']
#print('Split? ', split_mode)
def determinemusicfileFolderName():
"""