Add files via upload
This commit is contained in:
115
demucs/compressed.py
Normal file
115
demucs/compressed.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from fractions import Fraction
|
||||
from concurrent import futures
|
||||
|
||||
import musdb
|
||||
from torch import distributed
|
||||
|
||||
from .audio import AudioFile
|
||||
|
||||
|
||||
def get_musdb_tracks(root, *args, **kwargs):
    """Return a dict mapping every MUSDB track name to its path.

    `root` and any extra positional or keyword arguments are forwarded
    unchanged to `musdb.DB`.
    """
    database = musdb.DB(root, *args, **kwargs)
    paths_by_name = {}
    for track in database:
        paths_by_name[track.name] = track.path
    return paths_by_name
|
||||
|
||||
|
||||
class StemsSet:
    """Indexable dataset of normalized audio excerpts, one file per track.

    Tracks are ordered by name. When `duration` is set, each track contributes
    several overlapping excerpts spaced by `stride`; otherwise each track is a
    single example. Every example is normalized with the track's own
    ``"mean"`` and ``"std"`` metadata entries.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks: mapping from track name to the path of its audio file.
            metadata: mapping from track name to a dict with at least the
                keys ``"duration"``, ``"mean"`` and ``"std"``. The per-track
                dicts are copied, not modified.
            duration: length of one excerpt, in the same unit as
                ``metadata[name]["duration"]``; `None` means one example
                per track.
            stride: offset between the starts of consecutive excerpts, in the
                same unit as `duration`.
            samplerate: forwarded to `AudioFile.read`.
            channels: forwarded to `AudioFile.read`.
            streams: forwarded to `AudioFile.read`.

        Raises:
            ValueError: if a track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            # Copy so the caller's metadata is not polluted with path/name.
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            self.metadata.append(meta)
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
        # Sort by name so that indexing does not depend on the order of `tracks`.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        """Return the total number of examples over all tracks."""
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        """Return how many examples the track described by `meta` provides."""
        if self.duration is None:
            return 1
        else:
            return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Resolve a global example index to ``(track metadata, local index)``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if `index` is outside ``[-len(self), len(self))``.
        """
        offset = index
        if offset < 0:
            offset += len(self)
        if offset >= 0:
            for meta in self.metadata:
                examples = self._examples_count(meta)
                if offset < examples:
                    return meta, offset
                # Not in this track: skip all of its examples.
                offset -= examples
        raise IndexError(f"StemsSet index {index} is out of range")

    def track_metadata(self, index):
        """Return the metadata dict of the track that contains example `index`.

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, _ = self._locate(index)
        return meta

    def __getitem__(self, index):
        """Read example `index` and return it normalized by the track statistics.

        Raising `IndexError` past the end (rather than returning `None`) is
        what lets ``for example in dataset`` terminate.

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, offset = self._locate(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        return (streams - meta["mean"]) / meta["std"]
|
||||
|
||||
|
||||
def _get_track_metadata(path):
    """Return the duration, std and mean of the audio file at `path`.

    The statistics are taken from stream 0 read as mono at 44.1 kHz. With any
    other settings the data will not be perfectly normalized, but it should be
    good enough.
    """
    track = AudioFile(path)
    reference = track.read(streams=0, channels=1, samplerate=44100)
    stats = {}
    stats["duration"] = track.duration
    stats["std"] = reference.std().item()
    stats["mean"] = reference.mean().item()
    return stats
|
||||
|
||||
|
||||
def _build_metadata(tracks, workers=10):
|
||||
pendings = []
|
||||
with futures.ProcessPoolExecutor(workers) as pool:
|
||||
for name, path in tracks.items():
|
||||
pendings.append((name, pool.submit(_get_track_metadata, path)))
|
||||
return {name: p.result() for name, p in pendings}
|
||||
|
||||
|
||||
def _build_musdb_metadata(path, musdb, workers):
    """Compute the metadata of all MUSDB tracks and save it as JSON at `path`.

    Args:
        path: destination file, a `pathlib.Path`-like object; missing parent
            directories are created.
        musdb: root of the MUSDB dataset, forwarded to `get_musdb_tracks`.
            NOTE: this parameter shadows the `musdb` module inside this function.
        workers: number of worker processes used to analyze the tracks.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Close the file explicitly so the JSON is fully flushed to disk before
    # this function returns, instead of relying on garbage collection.
    with open(path, "w") as file:
        json.dump(metadata, file)
|
||||
|
||||
|
||||
def get_compressed_datasets(args, samples):
    """Build the MUSDB training and validation `StemsSet` datasets.

    The metadata file ``args.metadata / "musdb.json"`` is created by rank 0 if
    it does not exist yet; when running distributed, the other ranks wait at a
    barrier before reading it.

    Args:
        args: object providing the attributes `metadata`, `rank`, `musdb`,
            `workers`, `world_size`, `samplerate`, `data_stride` and
            `audio_channels`.
        samples: length of one training excerpt, in samples at
            `args.samplerate`.

    Returns:
        A ``(train_set, valid_set)`` tuple. The training set yields excerpts of
        `samples` samples spaced by `args.data_stride` samples and skips
        stream 0; the validation set yields whole tracks with all streams.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        # Make sure rank 0 is done writing the metadata before anyone reads it.
        distributed.barrier()
    # Use a context manager so the file handle is closed deterministically.
    with open(metadata_file) as file:
        metadata = json.load(file)
    # Exact fractions avoid rounding errors when converting samples to time.
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set
|
||||
Reference in New Issue
Block a user