# https://github.com/brown-palm/syntheory/blob/main/embeddings/models.py
# https://huggingface.co/docs/transformers/main/model_doc/musicgen
# https://huggingface.co/docs/transformers/main/en/model_doc/encodec#transformers.EncodecFeatureExtractor
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/modeling_musicgen.py
import os
import torch
import librosa
import util as um
import util_hf as uhf
import argparse
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from distutils.util import strtobool
# outputs of lm are all tuples where each entry is a tensor
# sizes:
# decoder.hidden_states
# --- large: 1, 200, 2048 (49 hidden states, 0-47: AddBackward0, 48: NativeLayerNormBackward0)
# --- medium: 1, 200, 1536 (49 hidden states, 0-47: AddBackward0, 48: NativeLayerNormBackward0)
# --- small: 1, 200, 1024 (25 hidden states, 0-23: AddBackward0, 24: NativeLayerNormBackward0)
# decoder.attention
# --- large: 1, 32, 200, 200 (48, all ViewBackward0)
# --- medium: 1, 24, 200, 200 (48, all ViewBackward0)
# --- small: 1, 16, 200, 200 (24, all ViewBackward0)
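# a hedged sketch (not part of the original file) of how those tuples get stacked and pooled,
# mirroring the extraction loop further below; `outputs` is assumed to come from a forward pass
# with output_hidden_states=True and output_attentions=True:
#   hs = torch.vstack(outputs.decoder_hidden_states)   # (num_layers + 1, 200, hidden_size)
#   hs_mp = torch.mean(hs, axis=1)                      # (num_layers + 1, hidden_size)
#   at = torch.vstack(outputs.decoder_attentions)       # (num_layers, num_heads, 200, 200)
#   at_mp = torch.mean(at, axis=(2, 3))                 # (num_layers, num_heads)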
model_sr = 32000
model_num_layers = {"facebook/musicgen-small": 24, "facebook/musicgen-medium": 48,
                    "facebook/musicgen-large": 48}
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-s", "--size", type=str, default="small", help="audio,small, medium, large")
parser.add_argument("-n", "--normalize", type=strtobool, default=True, help="normalize audio")
parser.add_argument("-d", "--debug", type=strtobool, default=False, help="debug")
parser.add_argument("-m", "--meanpool", type=strtobool, default=True, help="to meanpool")
parser.add_argument("-hl", "--save_hidden", type=strtobool, default=True, help="save hidden states")
parser.add_argument("-at", "--save_attn", type=strtobool, default=False, help="save attention")
parser.add_argument("-hf", "--hf_dataset", type=str, default="", help="use old hugging face dataset if passed")
args = parser.parse_args()
model_size = args.size
normalize = args.normalize
debug = args.debug
meanpool = args.meanpool
save_hidden = args.save_hidden
save_attn = args.save_attn
hfds_str = args.hf_dataset
acts_dir = 'acts'
# assumption: load_dir is the directory holding the wav files; um.load_wav(...) below uses it
# but the original file never defines it
load_dir = 'wav'
path_list = um.path_list('wav')
use_hf = len(hfds_str) > 0
if use_hf == True:
    path_list = uhf.load_syntheory_train_dataset(hfds_str)
    acts_dir = 'hf_acts'
text = ""
model_str = "facebook/musicgen-small"
emb_dir = "mg_small"
log_path = "mg_medium"
if model_size == "medium":
model_str = "facebook/musicgen-medium"
emb_dir = "mg_medium"
elif model_size == "large":
model_str = "facebook/musicgen-large"
emb_dir = "mg_large"
elif model_size == "audio":
model_str = "facebook/musicgen-large"
emb_dir = "mg_audio"
num_layers = model_num_layers[model_str]
device = 'cpu'
log_path = emb_dir
if torch.cuda.is_available() == True:
    device = 'cuda'
    torch.cuda.empty_cache()
torch.set_default_device(device)
if model_size == "audio":
num_layers = -1
out_dir = None
if meanpool == True:
    out_dir = um.by_projpath(os.path.join(acts_dir, f'{emb_dir}_mp'), make_dir = True)
else:
    out_dir = um.by_projpath(os.path.join(acts_dir, emb_dir), make_dir = True)
log = um.by_projpath(os.path.join('log', f'{log_path}.log'))
#dsamp_rate = 22050
layer_act = 36
dur = 4.0
# https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/musicgen
# ----- language model -----
# decoder_hidden_states ---
# output_hidden_states = True or config.output_hidden_states = True
# tuple of torch.floattensor: (output of embeddings, one output of each layer of following shape)
# shape: (0: batch_size, 1: sequence_length, 2:hidden_size)
# "hidden states of the decoder at the output of each layer + initial embedding inputs"
# decoder_attentions ---
# output_attentions = True or config.output_attentions = True
# tuple of torch.floattensor (one per layer) with the following shape
# shape: batch_size, num_heads, sequence_length, sequence_length
# "attention weights of the decoder, after attention softmax, used to compute the weighted average in the self-attention heads"
# ---- encodec ------
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/modeling_musicgen.py
# get_audio_encoder returns self.audio_encoder
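# a hedged sketch (not part of the original file) of the audio-encoder path used below; `model`
# and `proc` are the objects loaded just below, and `wav` is a placeholder name for a mono
# waveform at model_sr:
#   enc = model.get_audio_encoder()                                  # EnCodec audio encoder
#   feats = proc(audio=wav, sampling_rate=model_sr, padding=True, return_tensors='pt')
#   x = feats['input_values']                                        # (1, 1, num_samples)
#   for layer in enc.encoder.layers:                                 # per-layer pass, as in syntheory
#       x = layer(x)
#   enc_out = enc.encode(**feats)                                    # .audio_codes / .audio_scales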
#proc = AutoProcessor.from_pretrained(model_str, device_map = device)
#model = MusicgenForConditionalGeneration.from_pretrained(model_str, device_map = device)
proc = AutoProcessor.from_pretrained(model_str)
#model = MusicgenForConditionalGeneration.from_pretrained(model_str)
model = MusicgenForConditionalGeneration.from_pretrained(model_str, device_map=device)
#proc.to(device)
#model.to(device)
model_sr = model.config.audio_encoder.sampling_rate
#layer_acts = [x for x in range(1,73)]
if os.path.isfile(log):
    os.remove(log)
with open(log, 'a') as lf:
    for fidx, f in enumerate(path_list):
        if debug == True:
            if fidx > 0: break
        dhs = None
        dhs_mp = None
        dat = None
        dat_mp = None
        audio = None
        procd = None
        aud_sr = False
        if use_hf == False:
            audio = um.load_wav(f, dur = dur, normalize = normalize, sr = model_sr, load_dir = load_dir)
        else:
            audio, aud_sr = uhf.get_from_entry_syntheory_audio(f, mono=True, normalize = normalize, dur = dur)
            if aud_sr != model_sr:
                audio = librosa.resample(audio, orig_sr=aud_sr, target_sr=model_sr)
        if model_size == 'audio':
            procd = proc(audio = audio, sampling_rate = model_sr, padding=True, return_tensors = 'pt')
        else:
            procd = proc(audio = audio, text = text, sampling_rate = model_sr, padding=True, return_tensors = 'pt')
        procd.to(device)
        outname = None
        if use_hf == False:
            print(f'loading {f}', file=lf)
            outname = um.ext_replace(f, new_ext="pt")
        else:
            print(f"loading {f['audio']['path']}", file=lf)
            outname = um.ext_replace(f['audio']['path'], new_ext="pt")
        outpath = os.path.join(out_dir, outname)
        if model_size == 'audio':
            enc = model.get_audio_encoder()
            out = procd['input_values']
            # iterating through layers as in original syntheory codebase
            # https://github.com/brown-palm/syntheory/blob/main/embeddings/models.py
            for layer in enc.encoder.layers:
                out = layer(out)
            # output shape, (1, 128, 200), where 200 are the timesteps
            # so average across timesteps for mean pooling
            if meanpool == True:
                # gives shape (128)
                out_mp = torch.mean(out, axis=2).squeeze()
                torch.save(out_mp, outpath)
            else:
                # still need to squeeze
                # gives shape (128, 200)
                out_save = out.squeeze()
                torch.save(out_save, outpath)
            if debug == True:
                encprops = dir(enc)
                iptprops = dir(procd)
                print("processed properties", file=lf)
                print(iptprops, file=lf)
                print("encoder properties", file=lf)
                print(encprops, file=lf)
                encprops2 = dir(enc.encoder)
                print("encoder properties 2", file=lf)
                print(encprops2, file=lf)
                numlayers = len(enc.encoder.layers)
                print(f"num layers: {numlayers}", file=lf)
                # going through the audio encoder's encode()
                out2 = enc.encode(**procd)
                print("iterating through layers", file=lf)
                print(out.shape, file=lf)
                print("out2 properties", file=lf)
                out2prop = dir(out2)
                print(out2prop, file=lf)
                print("out2 outputs: audio_scales", file=lf)
                print(out2.audio_scales, file=lf)
                print("out2 outputs: audio_codes", file=lf)
                print(out2.audio_codes, file=lf)
                out3 = enc(**procd)
                print("out3 properties", file=lf)
                out3prop = dir(out3)
                print(out3prop, file=lf)
                out3av = out3['audio_values']
                avshape = out3av.shape
                print(f"out3 output: audio_values ({avshape})", file=lf)
                print(out3av, file=lf)
                procd2 = proc(audio = audio, text = text, sampling_rate = model_sr, padding=True, return_tensors = 'pt')
                procd2.to(device)
                outputs = model(**procd2, output_attentions=True, output_hidden_states=True)
                enc_lh = outputs.encoder_last_hidden_state
                print("last hidden state of encoder", file=lf)
                print(enc_lh.shape, file=lf)
                print("iteration output", file=lf)
                print(out, file=lf)
                print("last hidden state output", file=lf)
                print(enc_lh, file=lf)
                enc_h = outputs.encoder_hidden_states
                enc_at = outputs.encoder_attentions
                enc_h_sz = len(enc_h)
                enc_at_sz = len(enc_at)
                print(f'encoder hidden states: {enc_h_sz}', file=lf)
                for i in range(enc_h_sz):
                    print(f'----{i}----', file=lf)
                    print(enc_h[i].shape, file=lf)
                    print(enc_h[i].grad_fn, file=lf)
                print(f'encoder hidden states output: {enc_h_sz}', file=lf)
                for i in range(enc_h_sz):
                    print(f'----{i}----', file=lf)
                    print(enc_h[i], file=lf)
                print(f'encoder attentions: {enc_at_sz}', file=lf)
                for i in range(enc_at_sz):
                    print(f'----{i}----', file=lf)
                    print(enc_at[i].shape, file=lf)
                    print(enc_at[i].grad_fn, file=lf)
                print(f'encoder attentions output: {enc_at_sz}', file=lf)
                for i in range(enc_at_sz):
                    print(f'----{i}----', file=lf)
                    print(enc_at[i], file=lf)
        else:
            outputs = model(**procd, output_attentions=True, output_hidden_states=True)
            dhs = torch.vstack(outputs.decoder_hidden_states)
            dat = torch.vstack(outputs.decoder_attentions)
            if meanpool == True:
                # gives shape (25/49, 1024/1536/2048)
                if save_hidden == True:
                    dhs_mp = torch.mean(dhs, axis=1)
                    torch.save(dhs_mp, outpath)
                # gives shape (24/48, 16/24/32)
                if save_attn == True:
                    dat_mp = torch.mean(dat, axis=(2,3))
                    torch.save(dat_mp, outpath)
            else:
                # gives shape (25/49, 200, 1024/1536/2048)
                if save_hidden == True:
                    torch.save(dhs, outpath)
                # gives shape (24/48, 16/24/32, 200, 200)
                if save_attn == True:
                    torch.save(dat, outpath)
            if debug == True:
                dhs_sz = len(dhs)
                print(f'hidden states: {dhs_sz}', file=lf)
                for i in range(dhs_sz):
                    print(f'----{i}----', file=lf)
                    print(dhs[i].shape, file=lf)
                    print(dhs[i].grad_fn, file=lf)
                dat_sz = len(dat)
                print(f'attention: {dat_sz}', file=lf)
                for i in range(dat_sz):
                    print(f'----{i}----', file=lf)
                    print(dat[i].shape, file=lf)
                    print(dat[i].grad_fn, file=lf)
        #torch.save(reps[layer_act], outpath)
        #jml.lib.empty_cache()