SingleSongTest.py
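"""Transcribe a single song with a pre-trained model.

Extracts spectral features from the input audio, runs the model to get
predictions, and post-processes them into a MIDI object that can
optionally be written to disk (see ``--to-midi``).
"""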
import os

# Silence TensorFlow's INFO/WARNING log messages; this must be set before
# the TensorFlow-backed project modules below are imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'

import argparse
import numpy as np
import h5py

from project.Feature.FeatureFirstLayer import feature_extraction
from project.Feature.FeatureSecondLayer import fetch_harmonic
from project.Predict import predict_v1
from project.postprocess import MultiPostProcess
from project.utils import ModelInfo
from project.configuration import MusicNet_Instruments, HarmonicNum
def create_parser():
    parser = argparse.ArgumentParser(description="Transcribe the given audio.")
    parser.add_argument("-i", "--input-audio",
                        help="Path to the input audio you want to transcribe.",
                        type=str, required=True)
    parser.add_argument("-m", "--model-path",
                        help="Path to the pre-trained model.",
                        type=str, required=True)
    parser.add_argument("-o", "--output-fig-name",
                        help="Name of the piano-roll figure of the transcription to save.",
                        type=str, default="Piano Roll")
    parser.add_argument("--to-midi",
                        help="Also output the transcription result to a MIDI file.",
                        type=str)
    parser.add_argument("--onset-th", help="Onset threshold (5~8).", type=float)
    return parser
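
# Example invocation (a sketch; the audio and model paths are placeholders,
# and an onset threshold of 6 sits inside the documented 5~8 range):
#
#   python SingleSongTest.py -i some_song.wav -m ./model/my_checkpoint \
#       --to-midi out.mid --onset-th 6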
def main():
    parser = create_parser()
    args = parser.parse_args()

    # Pre-process the input audio into multi-channel spectral features.
    assert os.path.isfile(args.input_audio), \
        f"The given path is not a file! Please check your input again. Given input: {args.input_audio}"
    print("Processing features of input audio: {}".format(args.input_audio))
    Z, tfrL0, tfrLF, tfrLQ, t, cenf, f = feature_extraction(args.input_audio)

    # Load the pre-trained model, overriding its stored onset threshold
    # if one was given on the command line.
    minfo = ModelInfo()
    model = minfo.load_model(args.model_path)
    minfo.onset_th = minfo.onset_th if args.onset_th is None else args.onset_th
    print(minfo)

    # Arrange the features to match the model's expected input channels.
    if minfo.feature_type == "HCFP":
        # Harmonic model: stack spectrum and cepstrum harmonics as channels.
        assert len(minfo.input_channels) == (HarmonicNum * 2 + 2)
        spec = []
        ceps = []
        for i in range(HarmonicNum + 1):
            spec.append(fetch_harmonic(tfrL0, cenf, i))
            ceps.append(fetch_harmonic(tfrLQ, cenf, i))
        spec = np.transpose(np.array(spec), axes=(2, 1, 0))
        ceps = np.transpose(np.array(ceps), axes=(2, 1, 0))
        feature = np.dstack((spec, ceps))
    else:
        # Plain model: up to four channels (Z, tfrL0, tfrLF, tfrLQ).
        assert len(minfo.input_channels) <= 4
        feature = np.array([Z, tfrL0, tfrLF, tfrLQ])
        feature = np.transpose(feature, axes=(2, 1, 0))

    print("Predicting...")
    pred = predict_v1(feature[:, :, minfo.input_channels], model, minfo.timesteps, batch_size=4)

    # Map the model's label type to the corresponding post-processing mode.
    mode_mapping = {
        "frame": "true_frame",
        "frame_onset": "note",
        "multi_instrument_frame": "true_frame",
        "multi_instrument_note": "note"
    }
    midi = MultiPostProcess(
        pred,
        mode=mode_mapping[minfo.label_type],
        onset_th=minfo.onset_th,
        dura_th=minfo.dura_th,
        frm_th=minfo.frm_th,
        inst_th=minfo.inst_th,
        t_unit=0.02
    )

    if args.to_midi is not None:
        midi.write(args.to_midi)
        print("Midi written as {}".format(args.to_midi))


if __name__ == "__main__":
    main()
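
# Quick sanity check of the written MIDI file (a sketch, assuming the optional
# pretty_midi package is available; it is not a dependency of this script):
#
#   import pretty_midi
#   pm = pretty_midi.PrettyMIDI("out.mid")
#   for inst in pm.instruments:
#       print(inst.program, len(inst.notes))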