Add files via upload
This commit is contained in:
109
demucs/test.py
Normal file
109
demucs/test.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import gzip
|
||||
import sys
|
||||
from concurrent import futures
|
||||
|
||||
import musdb
|
||||
import museval
|
||||
import torch as th
|
||||
import tqdm
|
||||
from scipy.io import wavfile
|
||||
from torch import distributed
|
||||
|
||||
from .audio import convert_audio
|
||||
from .utils import apply_model
|
||||
|
||||
|
||||
def evaluate(model,
|
||||
musdb_path,
|
||||
eval_folder,
|
||||
workers=2,
|
||||
device="cpu",
|
||||
rank=0,
|
||||
save=False,
|
||||
shifts=0,
|
||||
split=False,
|
||||
overlap=0.25,
|
||||
is_wav=False,
|
||||
world_size=1):
|
||||
"""
|
||||
Evaluate model using museval. Run the model
|
||||
on a single GPU, the bottleneck being the call to museval.
|
||||
"""
|
||||
|
||||
output_dir = eval_folder / "results"
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
json_folder = eval_folder / "results/test"
|
||||
json_folder.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# we load tracks from the original musdb set
|
||||
test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
|
||||
src_rate = 44100 # hardcoded for now...
|
||||
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
p.grad = None
|
||||
|
||||
pendings = []
|
||||
with futures.ProcessPoolExecutor(workers or 1) as pool:
|
||||
for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
|
||||
track = test_set.tracks[index]
|
||||
|
||||
out = json_folder / f"{track.name}.json.gz"
|
||||
if out.exists():
|
||||
continue
|
||||
|
||||
mix = th.from_numpy(track.audio).t().float()
|
||||
ref = mix.mean(dim=0) # mono mixture
|
||||
mix = (mix - ref.mean()) / ref.std()
|
||||
mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
|
||||
estimates = apply_model(model, mix.to(device),
|
||||
shifts=shifts, split=split, overlap=overlap)
|
||||
estimates = estimates * ref.std() + ref.mean()
|
||||
|
||||
estimates = estimates.transpose(1, 2)
|
||||
references = th.stack(
|
||||
[th.from_numpy(track.targets[name].audio).t() for name in model.sources])
|
||||
references = convert_audio(references, src_rate,
|
||||
model.samplerate, model.audio_channels)
|
||||
references = references.transpose(1, 2).numpy()
|
||||
estimates = estimates.cpu().numpy()
|
||||
win = int(1. * model.samplerate)
|
||||
hop = int(1. * model.samplerate)
|
||||
if save:
|
||||
folder = eval_folder / "wav/test" / track.name
|
||||
folder.mkdir(exist_ok=True, parents=True)
|
||||
for name, estimate in zip(model.sources, estimates):
|
||||
wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
|
||||
|
||||
if workers:
|
||||
pendings.append((track.name, pool.submit(
|
||||
museval.evaluate, references, estimates, win=win, hop=hop)))
|
||||
else:
|
||||
pendings.append((track.name, museval.evaluate(
|
||||
references, estimates, win=win, hop=hop)))
|
||||
del references, mix, estimates, track
|
||||
|
||||
for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
|
||||
if workers:
|
||||
pending = pending.result()
|
||||
sdr, isr, sir, sar = pending
|
||||
track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
|
||||
for idx, target in enumerate(model.sources):
|
||||
values = {
|
||||
"SDR": sdr[idx].tolist(),
|
||||
"SIR": sir[idx].tolist(),
|
||||
"ISR": isr[idx].tolist(),
|
||||
"SAR": sar[idx].tolist()
|
||||
}
|
||||
|
||||
track_store.add_target(target_name=target, values=values)
|
||||
json_path = json_folder / f"{track_name}.json.gz"
|
||||
gzip.open(json_path, "w").write(track_store.json.encode('utf-8'))
|
||||
if world_size > 1:
|
||||
distributed.barrier()
|
||||
Reference in New Issue
Block a user