Add files via upload

Anjok07
2022-08-23 15:04:39 -05:00
committed by GitHub
parent f7e37b492b
commit 7778d4c3fb
19 changed files with 1735 additions and 581 deletions

@@ -10,11 +10,13 @@ interpolation between chunks, as well as the "shift trick".
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
from multiprocessing import Process, Queue, Pipe
import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
import tkinter as tk
from .demucs import Demucs
from .hdemucs import HDemucs
@@ -22,6 +24,7 @@ from .utils import center_trim, DummyPoolExecutor
Model = tp.Union[Demucs, HDemucs]
progress_bar_num = 0
class BagOfModels(nn.Module):
def __init__(self, models: tp.List[Model],
@@ -107,7 +110,6 @@ class TensorChunk:
assert out.shape[-1] == target_length
return out
def tensor_chunk(tensor_or_chunk):
if isinstance(tensor_or_chunk, TensorChunk):
return tensor_or_chunk
@@ -115,10 +117,9 @@ def tensor_chunk(tensor_or_chunk):
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model(model, mix, shifts=1, split=True,
overlap=0.25, transition_power=1., progress=False, device=None,
num_workers=0, pool=None):
def apply_model(model, mix, gui_progress_bar: tk.Variable, widget_text: tk.Text, update_prog, total_files, file_num, inference_type, shifts=1, split=True,
overlap=0.25, transition_power=1., progress=True, device=None,
num_workers=0, pool=None, segmen=False):
"""
Apply model to a given mixture.
@@ -136,6 +137,12 @@ def apply_model(model, mix, shifts=1, split=True,
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
"""
base_text = 'File {file_num}/{total_files} '.format(file_num=file_num,
total_files=total_files)
global fut_length
if device is None:
device = mix.device
else:
@@ -145,7 +152,12 @@ def apply_model(model, mix, shifts=1, split=True,
pool = ThreadPoolExecutor(num_workers)
else:
pool = DummyPoolExecutor()
kwargs = {
'gui_progress_bar': gui_progress_bar,
'widget_text': widget_text,
'update_prog': update_prog,
'segmen': segmen,
'shifts': shifts,
'split': split,
'overlap': overlap,
@@ -153,17 +165,35 @@ def apply_model(model, mix, shifts=1, split=True,
'progress': progress,
'device': device,
'pool': pool,
'total_files': total_files,
'file_num': file_num,
'inference_type': inference_type
}
if isinstance(model, BagOfModels):
# Special treatment for bag of models.
# We explicitly apply `apply_model` multiple times so that the random shifts
# are different for each model.
global bag_num
global current_model
global progress_bar
global prog_bar
#global percent_prog_del
#percent_prog_del = gui_progress_bar.get()
progress_bar = 0
prog_bar = 0
estimates = 0
totals = [0] * len(model.sources)
bag_num = len(model.models)
fut_length = 0
current_model = 0 #(bag_num + 1)
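# These globals let the recursive per-sub-model calls below report progress
# for the whole bag (bag_num models x their chunk futures) rather than
# restarting the count for each sub-model.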
for sub_model, weight in zip(model.models, model.weights):
original_model_device = next(iter(sub_model.parameters())).device
sub_model.to(device)
fut_length += fut_length
current_model += 1
out = apply_model(sub_model, mix, **kwargs)
sub_model.to(original_model_device)
for k, inst_weight in enumerate(weight):
@@ -179,6 +209,7 @@ def apply_model(model, mix, shifts=1, split=True,
model.to(device)
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
batch, channels, length = mix.shape
if split:
kwargs['split'] = False
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
@@ -202,9 +233,26 @@ def apply_model(model, mix, shifts=1, split=True,
future = pool.submit(apply_model, model, chunk, **kwargs)
futures.append((future, offset))
offset += segment
if progress:
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
for future, offset in futures:
if segmen:
fut_length = len(futures)
full_fut_length = (fut_length * bag_num)
send_back = full_fut_length * 2
progress_bar += 100
prog_bar += 1
full_step = (progress_bar / full_fut_length)
percent_prog = f"{base_text}Demucs Inference Progress: {prog_bar}/{full_fut_length} | {round(full_step)}%"
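# Each pipeline reserves a different slice of the per-file progress bar for
# Demucs: since send_back = 2 * full_fut_length, the step climbs from about
# 10% to 95% for 'demucs_only', 35% to 87.5% for 'inference_mdx', and
# 60% to 95% for 'inference_vr' as prog_bar approaches full_fut_length.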
if inference_type == 'demucs_only':
update_prog(gui_progress_bar, total_files, file_num,
step=(0.1 + (1.7/send_back * prog_bar)))
elif inference_type == 'inference_mdx':
update_prog(gui_progress_bar, total_files, file_num,
step=(0.35 + (1.05/send_back * prog_bar)))
elif inference_type == 'inference_vr':
update_prog(gui_progress_bar, total_files, file_num,
step=(0.6 + (0.7/send_back * prog_bar)))
widget_text.percentage(percent_prog)
#gui_progress_bar.set(step)
chunk_out = future.result()
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
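
For reference, here is a minimal sketch of how the reworked `apply_model` could be driven from a Tkinter front end. The `ConsoleText` widget and the `update_prog` callback are hypothetical stand-ins that only reproduce the interfaces exercised in this file (`percentage()` and the `update_prog(gui_progress_bar, total_files, file_num, step=...)` call); the real UVR widgets, the model loading, and the exact import path are not part of this diff.

import tkinter as tk
import torch as th
from demucs.apply import apply_model  # assumed package path for this module

class ConsoleText(tk.Text):
    # Stand-in for the UVR console widget; only the method used by
    # apply_model is provided.
    def percentage(self, text):
        print(text)

def update_prog(gui_progress_bar, total_files, file_num, step=0.0):
    # Stand-in callback: map the per-file step (0.0-1.0) onto the overall bar.
    gui_progress_bar.set(100 * (file_num - 1 + step) / total_files)

root = tk.Tk()
progress_var = tk.DoubleVar(master=root)
console = ConsoleText(root)

model = ...  # a loaded BagOfModels (segmen=True relies on the bag globals)
mix = th.zeros(1, 2, 44100 * 10)  # dummy stereo mixture, ~10 s at 44.1 kHz

out = apply_model(model, mix, progress_var, console, update_prog,
                  total_files=1, file_num=1, inference_type='demucs_only',
                  split=True, segmen=True)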