Add files via upload
This commit is contained in:
139
inference_v5.py
139
inference_v5.py
@@ -103,7 +103,7 @@ def determineModelFolderName():
|
||||
def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable,
|
||||
**kwargs: dict):
|
||||
|
||||
global model_params_d
|
||||
global gui_progress_bar
|
||||
global nn_arch_sizes
|
||||
global nn_architecture
|
||||
|
||||
@@ -115,9 +115,10 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
|
||||
global flac_type_set
|
||||
global mp3_bit_set
|
||||
global space
|
||||
|
||||
wav_type_set = data['wavtype']
|
||||
|
||||
gui_progress_bar = progress_var
|
||||
#Error Handling
|
||||
|
||||
runtimeerr = "CUDNN error executing cudnnSetTensorNdDescriptor"
|
||||
@@ -127,6 +128,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
file_err = "FileNotFoundError"
|
||||
ffmp_err = """audioread\__init__.py", line 116, in audio_open"""
|
||||
sf_write_err = "sf.write"
|
||||
demucs_model_missing_err = "is neither a single pre-trained model or a bag of models."
|
||||
|
||||
try:
|
||||
with open('errorlog.txt', 'w') as f:
|
||||
@@ -382,8 +384,12 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
global default_window_size
|
||||
global default_agg
|
||||
global normalization_set
|
||||
global update_prog
|
||||
|
||||
update_prog = update_progress
|
||||
default_window_size = data['window_size']
|
||||
default_agg = data['agg']
|
||||
space = ' '*90
|
||||
|
||||
stime = time.perf_counter()
|
||||
progress_var.set(0)
|
||||
@@ -432,6 +438,9 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
else:
|
||||
base_name = f'{data["export_path"]}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'
|
||||
|
||||
global inference_type
|
||||
|
||||
inference_type = 'inference_vr'
|
||||
model_name = os.path.basename(data[f'{data["useModel"]}Model'])
|
||||
model = vocal_remover.models[data['useModel']]
|
||||
device = vocal_remover.devices[data['useModel']]
|
||||
@@ -441,6 +450,8 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
progress_kwargs = {'progress_var': progress_var,
|
||||
'total_files': len(data['input_paths']),
|
||||
'file_num': file_num}
|
||||
progress_demucs_kwargs = {'total_files': len(data['input_paths']),
|
||||
'file_num': file_num, 'inference_type': inference_type}
|
||||
update_progress(**progress_kwargs,
|
||||
step=0)
|
||||
|
||||
@@ -503,7 +514,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
model_hash = hashlib.md5(open(ModelName,'rb').read()).hexdigest()
|
||||
model_params = []
|
||||
model_params = lib_v5.filelist.provide_model_param_hash(model_hash)
|
||||
print(model_params)
|
||||
#print(model_params)
|
||||
if model_params[0] == 'Not Found Using Hash':
|
||||
model_params = []
|
||||
model_params = lib_v5.filelist.provide_model_param_name(ModelName)
|
||||
@@ -622,8 +633,6 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
text_widget.write(base_text + 'Loading the stft of audio source...')
|
||||
|
||||
text_widget.write(' Done!\n')
|
||||
|
||||
text_widget.write(base_text + "Please Wait...\n")
|
||||
|
||||
X_spec_m = spec_utils.combine_spectrograms(X_spec_s, mp)
|
||||
|
||||
@@ -631,22 +640,47 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
|
||||
def inference(X_spec, device, model, aggressiveness):
|
||||
|
||||
def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness):
|
||||
def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness, tta=False):
|
||||
model.eval()
|
||||
|
||||
global active_iterations
|
||||
global progress_value
|
||||
|
||||
with torch.no_grad():
|
||||
preds = []
|
||||
|
||||
iterations = [n_window]
|
||||
|
||||
total_iterations = sum(iterations)
|
||||
|
||||
text_widget.write(base_text + "Processing "f"{total_iterations} Slices... ")
|
||||
if data['tta']:
|
||||
total_iterations = sum(iterations)
|
||||
total_iterations = total_iterations*2
|
||||
else:
|
||||
total_iterations = sum(iterations)
|
||||
|
||||
if tta:
|
||||
active_iterations = sum(iterations)
|
||||
active_iterations = active_iterations - 2
|
||||
total_iterations = total_iterations - 2
|
||||
else:
|
||||
active_iterations = 0
|
||||
|
||||
for i in tqdm(range(n_window)):
|
||||
update_progress(**progress_kwargs,
|
||||
step=(0.1 + (0.8/n_window * i)))
|
||||
progress_bar = 0
|
||||
for i in range(n_window):
|
||||
active_iterations += 1
|
||||
if data['demucsmodelVR']:
|
||||
update_progress(**progress_kwargs,
|
||||
step=(0.1 + (0.5/total_iterations * active_iterations)))
|
||||
else:
|
||||
update_progress(**progress_kwargs,
|
||||
step=(0.1 + (0.8/total_iterations * active_iterations)))
|
||||
start = i * roi_size
|
||||
progress_bar += 100
|
||||
progress_value = progress_bar
|
||||
active_iterations_step = active_iterations*100
|
||||
step = (active_iterations_step / total_iterations)
|
||||
|
||||
percent_prog = f"{base_text}Inference Progress: {active_iterations}/{total_iterations} | {round(step)}%"
|
||||
text_widget.percentage(percent_prog)
|
||||
X_mag_window = X_mag_pad[None, :, :, start:start + data['window_size']]
|
||||
X_mag_window = torch.from_numpy(X_mag_window).to(device)
|
||||
|
||||
@@ -656,7 +690,6 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
preds.append(pred[0])
|
||||
|
||||
pred = np.concatenate(preds, axis=2)
|
||||
text_widget.write('Done!\n')
|
||||
return pred
|
||||
|
||||
def preprocess(X_spec):
|
||||
@@ -691,7 +724,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
|
||||
|
||||
pred_tta = _execute(X_mag_pad, roi_size, n_window,
|
||||
device, model, aggressiveness)
|
||||
device, model, aggressiveness, tta=True)
|
||||
pred_tta = pred_tta[:, :, roi_size // 2:]
|
||||
pred_tta = pred_tta[:, :, :n_frame]
|
||||
|
||||
@@ -702,17 +735,16 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
aggressiveness = {'value': aggresive_set, 'split_bin': mp.param['band'][1]['crop_stop']}
|
||||
|
||||
if data['tta']:
|
||||
text_widget.write(base_text + "Running Inferences (TTA)...\n")
|
||||
text_widget.write(base_text + f"Running Inferences (TTA)... {space}\n")
|
||||
else:
|
||||
text_widget.write(base_text + "Running Inference...\n")
|
||||
text_widget.write(base_text + f"Running Inference... {space}\n")
|
||||
|
||||
pred, X_mag, X_phase = inference(X_spec_m,
|
||||
device,
|
||||
model, aggressiveness)
|
||||
|
||||
update_progress(**progress_kwargs,
|
||||
step=0.9)
|
||||
# Postprocess
|
||||
text_widget.write('\n')
|
||||
|
||||
if data['postprocess']:
|
||||
try:
|
||||
text_widget.write(base_text + 'Post processing...')
|
||||
@@ -743,19 +775,38 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
v_spec_m = X_spec_m - y_spec_m
|
||||
|
||||
def demix_demucs(mix):
|
||||
#print('shift_set ', shift_set)
|
||||
text_widget.write(base_text + "Running Demucs Inference...\n")
|
||||
text_widget.write(base_text + "Processing... ")
|
||||
|
||||
print(' Running Demucs Inference...')
|
||||
|
||||
if split_mode:
|
||||
text_widget.write(base_text + f'Running Demucs Inference... {space}')
|
||||
else:
|
||||
text_widget.write(base_text + f'Running Demucs Inference... ')
|
||||
|
||||
mix = torch.tensor(mix, dtype=torch.float32)
|
||||
ref = mix.mean(0)
|
||||
mix = (mix - ref.mean()) / ref.std()
|
||||
|
||||
widget_text = text_widget
|
||||
with torch.no_grad():
|
||||
sources = apply_model(demucs, mix[None], split=split_mode, device=device, overlap=overlap_set, shifts=shift_set, progress=False)[0]
|
||||
|
||||
text_widget.write('Done!\n')
|
||||
sources = apply_model(demucs,
|
||||
mix[None],
|
||||
gui_progress_bar,
|
||||
widget_text,
|
||||
update_prog,
|
||||
split=split_mode,
|
||||
device=device,
|
||||
overlap=overlap_set,
|
||||
shifts=shift_set,
|
||||
progress=False,
|
||||
segmen=True,
|
||||
**progress_demucs_kwargs)[0]
|
||||
|
||||
if split_mode:
|
||||
text_widget.write('\n')
|
||||
else:
|
||||
update_progress(**progress_kwargs,
|
||||
step=0.9)
|
||||
text_widget.write('Done!\n')
|
||||
|
||||
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
|
||||
sources[[0,1]] = sources[[1,0]]
|
||||
@@ -774,15 +825,9 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
|
||||
if data['demucsmodelVR']:
|
||||
demucs = HDemucs(sources=["other", "vocals"])
|
||||
text_widget.write(base_text + 'Loading Demucs model... ')
|
||||
update_progress(**progress_kwargs,
|
||||
step=0.95)
|
||||
path_d = Path('models/Demucs_Models/v3_repo')
|
||||
#print('What Demucs model was chosen? ', demucs_model_set)
|
||||
demucs = _gm(name=demucs_model_set, repo=path_d)
|
||||
text_widget.write('Done!\n')
|
||||
|
||||
#print('segment: ', data['segment'])
|
||||
|
||||
if data['segment'] == 'None':
|
||||
segment = None
|
||||
@@ -803,7 +848,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
else:
|
||||
if segment is not None:
|
||||
sub.segment = segment
|
||||
text_widget.write(base_text + "Segments set to "f"{segment}.\n")
|
||||
#text_widget.write(base_text + "Segments set to "f"{segment}.\n")
|
||||
except:
|
||||
segment = None
|
||||
if isinstance(demucs, BagOfModels):
|
||||
@@ -814,8 +859,6 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
if segment is not None:
|
||||
sub.segment = segment
|
||||
|
||||
#print('segment port-process: ', segment)
|
||||
|
||||
demucs.cpu()
|
||||
demucs.eval()
|
||||
|
||||
@@ -1039,7 +1082,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
text_widget.write(f'\nError Received:\n\n')
|
||||
text_widget.write(f'Could not write audio file.\n')
|
||||
text_widget.write(f'This could be due to low storage on target device or a system permissions issue.\n')
|
||||
text_widget.write(f"\nFor raw error details, go to the Error Log tab in the Help Guide.\n")
|
||||
text_widget.write(f"\nGo to the Settings Menu and click \"Open Error Log\" for raw error details.\n")
|
||||
text_widget.write(f'\nIf the error persists, please contact the developers.\n\n')
|
||||
text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')
|
||||
try:
|
||||
@@ -1084,6 +1127,28 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
button_widget.configure(state=tk.NORMAL) # Enable Button
|
||||
return
|
||||
|
||||
if demucs_model_missing_err in message:
|
||||
text_widget.write("\n" + base_text + f'Separation failed for the following audio file:\n')
|
||||
text_widget.write(base_text + f'"{os.path.basename(music_file)}"\n')
|
||||
text_widget.write(f'\nError Received:\n\n')
|
||||
text_widget.write(f'The selected Demucs model is missing.\n\n')
|
||||
text_widget.write(f'Please download the model or make sure it is in the correct directory.\n\n')
|
||||
text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')
|
||||
try:
|
||||
with open('errorlog.txt', 'w') as f:
|
||||
f.write(f'Last Error Received:\n\n' +
|
||||
f'Error Received while processing "{os.path.basename(music_file)}":\n' +
|
||||
f'Process Method: VR Architecture\n\n' +
|
||||
f'The selected Demucs model is missing.\n\n' +
|
||||
f'Please download the model or make sure it is in the correct directory.\n\n' +
|
||||
message + f'\nError Time Stamp [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\n')
|
||||
except:
|
||||
pass
|
||||
torch.cuda.empty_cache()
|
||||
progress_var.set(0)
|
||||
button_widget.configure(state=tk.NORMAL) # Enable Button
|
||||
return
|
||||
|
||||
print(traceback_text)
|
||||
print(type(e).__name__, e)
|
||||
print(message)
|
||||
@@ -1103,7 +1168,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
|
||||
text_widget.write("\n" + base_text + f'Separation failed for the following audio file:\n')
|
||||
text_widget.write(base_text + f'"{os.path.basename(music_file)}"\n')
|
||||
text_widget.write(f'\nError Received:\n')
|
||||
text_widget.write("\nFor raw error details, go to the Error Log tab in the Help Guide.\n")
|
||||
text_widget.write("\Go to the Settings Menu and click \"Open Error Log\" for raw error details.\n")
|
||||
text_widget.write("\n" + f'Please address the error and try again.' + "\n")
|
||||
text_widget.write(f'If this error persists, please contact the developers with the error details.\n\n')
|
||||
text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')
|
||||
|
||||
Reference in New Issue
Block a user