-
Notifications
You must be signed in to change notification settings - Fork 6
/
nsynth_subset.py
44 lines (39 loc) · 1.94 KB
/
nsynth_subset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Select the NSynth subset used for neural audio synthesis projects by Magenta.
The result may still differ slightly from the actual subset (the sample counts don't agree).
Alternatively, you can download the NSynth subset directly from TensorFlow Datasets instead of
downloading the whole dataset and filtering it with this script.
"""
import json
from tqdm import tqdm
# Follow instructions here: https://www.tensorflow.org/datasets/catalog/nsynth#nsynthgansynth_subset
# to filter out the subset
def _load_examples(path):
    """Parse the JSON metadata file at *path* and return the decoded object.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # Open as UTF-8 explicitly: without an encoding argument, open() uses the
    # locale's preferred encoding (often not UTF-8 on Windows), which can
    # mis-decode or reject a UTF-8 JSON file.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Per-split metadata; each object is iterated by key in the filtering step.
dic_train = _load_examples("E:/nsynth/nsynth-train/examples.json")
dic_valid = _load_examples("E:/nsynth/nsynth-valid/examples.json")
dic_test = _load_examples("E:/nsynth/nsynth-test/examples.json")
def _gansynth_subset_keys(examples):
    """Return the keys of *examples* whose entries pass the subset filter.

    An entry is kept when its "pitch" lies in [24, 84] (both ends included)
    and its "instrument_source_str" equals "acoustic". Keys are returned in
    the iteration order of *examples*; progress is shown with tqdm.
    """
    return [
        name
        for name in tqdm(examples)
        if 24 <= examples[name]["pitch"] <= 84
        and examples[name]["instrument_source_str"] == "acoustic"
    ]


# Apply the same criteria to every split, keeping the splits separate.
keys_train = _gansynth_subset_keys(dic_train)
keys_valid = _gansynth_subset_keys(dic_valid)
keys_test = _gansynth_subset_keys(dic_test)
# In total, there are 86,775 samples across all splits, but we keep the original split instead of using
# the alternate split mentioned here: https://www.tensorflow.org/datasets/catalog/nsynth#nsynthgansynth_subset
# so the composition is different.
# Report how many keys were kept in each split (train, valid, test).
print(len(keys_train), len(keys_valid), len(keys_test))

# Write each split's kept keys to its own text file in the current working
# directory, one key per line. An existing file of the same name is truncated.
for out_name, kept in (
    ("keys_train.txt", keys_train),
    ("keys_valid.txt", keys_valid),
    ("keys_test.txt", keys_test),
):
    with open(out_name, "w+") as out:
        out.writelines(name + "\n" for name in kept)