🚨 Add training compatibility for Musicgen-like models #29802
Conversation
Hi! Is it possible to fine-tune the Musicgen model currently? If so, is there anything I should keep in mind? It would be really helpful if you could share your opinions. Thanks!
Wonderful work! I'm currently attempting to fine-tune the Musicgen model using this code, but I haven't succeeded yet. Is the model ready for fine-tuning, and are there specific aspects I should be aware of? Any training tips or guidance you could provide would be greatly appreciated! Thank you so much!
Hey @arjunsinghrathore and @LiuZH-19, I'll likely release some fine-tuning code next week or the week after!
Fast tests look quite rigorous, but we can probably use the base one unless there's a specific embedding issue! Otherwise LGTM
```python
# per codebook cross-entropy
# -100 labels are ignored
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)

mask = labels != -100

# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
    codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
    codebook_mask = mask[..., codebook].contiguous().view(-1)
    codebook_labels = labels[..., codebook].contiguous().view(-1)

    loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
```
The padded labels should be set to `-100` outside of the modelling code, i.e. in the data collator. Then, we don't need any of this masking logic, since the CE loss masks out `-100` values by default (see the `ignore_index` argument).
Suggested change:
```diff
-        # per codebook cross-entropy
-        # -100 labels are ignored
-        labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-        mask = labels != -100
-        # per codebook cross-entropy
-        for codebook in range(self.config.num_codebooks):
-            codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
-            codebook_mask = mask[..., codebook].contiguous().view(-1)
-            codebook_labels = labels[..., codebook].contiguous().view(-1)
-            loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+        # per codebook cross-entropy
+        for codebook in range(self.config.num_codebooks):
+            codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+            codebook_labels = labels[..., codebook].contiguous().view(-1)
+            loss += loss_fct(codebook_logits, codebook_labels)
```
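To illustrate the point about the data collator, here is a minimal sketch of such a collator (a hypothetical helper, not part of the PR; it assumes each example's audio codes are a `LongTensor` of shape `(seq_len, num_codebooks)`):

```python
import torch

def collate_audio_labels(batch_codes):
    """Pad per-example audio codes to the longest sequence, marking padded positions with -100."""
    max_len = max(codes.shape[0] for codes in batch_codes)
    num_codebooks = batch_codes[0].shape[1]
    labels = torch.full((len(batch_codes), max_len, num_codebooks), -100, dtype=torch.long)
    for i, codes in enumerate(batch_codes):
        labels[i, : codes.shape[0]] = codes
    return labels
```

Since `torch.nn.CrossEntropyLoss` uses `ignore_index=-100` by default, padding prepared this way is dropped from the loss without any explicit masking in the modelling code.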
```python
    loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])

loss = loss / self.config.num_codebooks
```
Reference from original AudioCraft code: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
Interesting that they average over the codebooks rather than taking a true average over all labels. Given that the sequence length for music generation is large (1500 tokens), the difference is going to be negligible.
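For context, a small illustrative sketch of the two averaging schemes being compared (not library code; shapes are assumed to match the snippet above, i.e. `logits` is `(batch, num_codebooks, seq_len, vocab)` and `labels` is `(batch, seq_len, num_codebooks)` with `-100` on ignored positions):

```python
import torch
import torch.nn.functional as F

def loss_mean_over_codebooks(logits, labels):
    # what the code above (and AudioCraft) does: one CE per codebook, then average the codebook losses
    per_codebook = [
        F.cross_entropy(logits[:, k].reshape(-1, logits.shape[-1]), labels[..., k].reshape(-1))
        for k in range(logits.shape[1])
    ]  # F.cross_entropy ignores -100 labels by default
    return torch.stack(per_codebook).mean()

def loss_mean_over_all_labels(logits, labels):
    # a "true" average over every non-ignored label, regardless of which codebook it belongs to
    flat_logits = logits.transpose(1, 2).reshape(-1, logits.shape[-1])  # (batch * seq * codebooks, vocab)
    return F.cross_entropy(flat_logits, labels.reshape(-1))
```

The two only differ when the number of non-ignored labels varies across codebooks, which is why the difference is negligible at the sequence lengths used for music generation.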
Let's add this link in a comment above for any unsuspecting future code reader to have context
```diff
@@ -1340,15 +1343,22 @@ def forward(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
```
c.f. this docstring that explains how the `-100` padding token should be set
```python
# Contrarily to the initial method, we don't unfreeze freezed parameters.
# Otherwise, it'll mess with the freezed sinusoidal embeddings
```
It'll change the weights of the pre-trained embeddings for sure, but the code should still run right? Or is there an issue where the code won't run if we train the embeddings? Unless this is the case, I would just use the base `check_training_gradient_checkpointing` method for simplicity.
The comments are not really clear, I'll make them clearer:
- the ConditionalGeneration models have the audio encoder, which is never used outside of `.generate`; if we don't freeze it, we'll have some issues because trainable weights won't have seen any gradients.
- the CausalModel's sinusoidal embeddings are frozen and should stay frozen (they shouldn't have been turned into a `Parameter` in the first place): `self.weights.requires_grad = False` (see the sketch below).
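As a rough sketch of the pattern being described (this is not the Musicgen implementation, just an illustration of sinusoidal position embeddings stored as an explicitly frozen `Parameter`):

```python
import math
import torch
from torch import nn

class FrozenSinusoidalPositionalEmbedding(nn.Module):
    """Fixed sinusoidal position embeddings that must never receive gradient updates."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        half_dim = embedding_dim // 2
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * (-math.log(10000.0) / (half_dim - 1)))
        angles = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
        # kept as a Parameter (e.g. for state-dict compatibility) but explicitly frozen
        self.weights = nn.Parameter(torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1))
        self.weights.requires_grad = False

    def forward(self, position_ids: torch.LongTensor) -> torch.Tensor:
        return self.weights[position_ids]
```

Because the weights remain a `Parameter`, a helper that blindly re-enables `requires_grad` on every parameter would start training them, which is what the overridden check above avoids.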
Hey @amyeroberts, gentle ping to ask for a review! Many thanks for your help!
Thanks for adding this capability!
Mostly small comments. Only concern is backwards compatibility wrt the loss for musicgen
""" | ||
|
||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
if (labels is not None) and (input_ids is None and inputs_embeds is None): | ||
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) |
It's a bit funny to do this on the `input_ids` - normally it's only done on the `decoder_input_ids`.
Well, it's a bit confusing, but `input_ids` in MusicgenForCausalLM actually corresponds to the audio input ids (i.e. the decoder input ids)!
Ah, ok! Makes sense :)
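For readers unfamiliar with the pattern, here is the generic shape of such a helper (a sketch of the usual seq2seq `shift_tokens_right`; the Musicgen variant additionally has to deal with the codebook dimension):

```python
import torch

def shift_tokens_right(labels: torch.LongTensor, pad_token_id: int, decoder_start_token_id: int) -> torch.LongTensor:
    """Shift label ids one position to the right and prepend the decoder start token."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # replace any -100 used for loss masking by a real pad token so the ids are valid inputs
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
```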
```diff
         return Seq2SeqLMOutput(
-            loss=loss,
+            loss=decoder_outputs.loss,
```
hmmm, the problem with this is it's not backwards compatible - users are now going to get different values of loss than before
I haven't seen any users using this or mentioning it, tbh (besides, it's totally wrong!).
How should we best handle this? Maybe by adding a breaking flag in the PR name?
If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes
```python
def freeze_encoders(self, freeze_text_encoder=True):
    if freeze_text_encoder:
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        self.text_encoder._requires_grad = False

    for param in self.audio_encoder.parameters():
        param.requires_grad = False
    self.audio_encoder._requires_grad = False
```
I'm not a fan of this structure. In general, we don't add freeze methods to our models and leave that to the user to handle - although I see audio models appear to be the exception! It's tidier to split this up into `freeze_text_encoder` and `freeze_audio_encoder` and then just call them separately, or add an additional `freeze_audio_encoder` argument.
I've split these up!
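For reference, a sketch of what the split might look like (names follow the snippet above; this is not copied from the merged code):

```python
def freeze_text_encoder(self):
    """Disable gradient computation for the text encoder, keeping it fixed during training."""
    for param in self.text_encoder.parameters():
        param.requires_grad = False
    self.text_encoder._requires_grad = False

def freeze_audio_encoder(self):
    """Disable gradient computation for the audio encoder, which is only used by `.generate`."""
    for param in self.audio_encoder.parameters():
        param.requires_grad = False
    self.audio_encoder._requires_grad = False
```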
```diff
@@ -2428,6 +2442,16 @@ def _maybe_initialize_input_ids_for_generation(
                 break
         return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

+    def freeze_encoders(self, freeze_text_encoder=True):
```
Same comment here
```diff
@@ -2533,6 +2550,16 @@ def resize_token_embeddings(self, *args, **kwargs):
             " model.decoder.resize_token_embeddings(...))"
         )

+    def freeze_encoders(self, freeze_text_encoder=True):
```
Should have a docstring as it's a public method
```diff
@@ -2189,6 +2288,26 @@ def test_eager_matches_sdpa_generate(self):

         self.assertTrue(torch.allclose(res_eager, res_sdpa))

+    def test_requires_grad_with_frozen_encoders(self):
```
❤️
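A rough sketch of what such a test can check (illustrative only, not the test that was merged):

```python
def test_requires_grad_with_frozen_encoders(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        # freezing the audio encoder should not touch the text encoder, and vice versa
        model = model_class(config)
        model.freeze_audio_encoder()
        self.assertFalse(any(p.requires_grad for p in model.audio_encoder.parameters()))
        self.assertTrue(all(p.requires_grad for p in model.text_encoder.parameters()))

        model = model_class(config)
        model.freeze_text_encoder()
        self.assertFalse(any(p.requires_grad for p in model.text_encoder.parameters()))
        self.assertTrue(all(p.requires_grad for p in model.audio_encoder.parameters()))
```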
Many thanks for the review @amyeroberts, I've changed the code according to your comments!
Awesome work - thanks for adding this feature!
""" | ||
|
||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
if (labels is not None) and (input_ids is None and inputs_embeds is None): | ||
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, ok! Makes sense :)
|
||
return Seq2SeqLMOutput( | ||
loss=loss, | ||
loss=decoder_outputs.loss, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes
* first modeling code
* make repository
* still WIP
* update model
* add tests
* add latest change
* clean docstrings and copied from
* update docstrings md and readme
* correct chroma function
* correct copied from and remove unreleated test
* add doc to toctree
* correct imports
* add convert script to notdoctested
* Add suggestion from Sanchit Co-authored-by: Sanchit Gandhi <[email protected]>
* correct get_uncoditional_inputs docstrings
* modify README according to SANCHIT feedback
* add chroma to audio utils
* clean librosa and torchaudio hard dependencies
* fix FE
* refactor audio decoder -> audio encoder for consistency with previous musicgen
* refactor conditional -> encoder
* modify sampling rate logics
* modify license at the beginning
* refactor all_self_attns->all_attentions
* remove ignore copy from causallm generate
* add copied from for from_sub_models
* fix make copies
* add warning if audio is truncated
* add copied from where relevant
* remove artefact
* fix convert script
* fix torchaudio and FE
* modify chroma method according to feedback-> better naming
* refactor input_values->input_features
* refactor input_values->input_features and fix import fe
* add input_features to docstrigs
* correct inputs_embeds logics
* remove dtype conversion
* refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation
* change warning for chroma length
* Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py Co-authored-by: Sanchit Gandhi <[email protected]>
* change way to save wav, using soundfile
* correct docs and change to soundfile
* fix import
* fix init proj layers
* add draft training
* fix cross entropy
* clean loss computation
* fix labels
* remove line breaks from md
* fix issue with docstrings
* add FE suggestions
* improve is in logics and remove useless imports
* remove custom from_pretrained
* simplify docstring code
* add suggestions for modeling tests
* make style
* update converting script with sanity check
* remove encoder attention mask from conditional generation
* replace musicgen melody checkpoints with official orga
* rename ylacombe->facebook in checkpoints
* fix copies
* remove unecessary warning
* add shape in code docstrings
* add files to slow doc tests
* fix md bug and add md to not_tested
* make fix-copies
* fix hidden states test and batching
* update training code
* add training tests for melody
* add training for o.g musicgen
* fix copied from
* remove final todos
* make style
* fix style
* add suggestions from review
* add ref to the original loss computation code
* rename method + fix labels in tests
* make style
---------
Co-authored-by: Sanchit Gandhi <[email protected]>
This PR aims to add training compatibility for Musicgen and Musicgen Melody.
The main difference with classic cross-entropy is that there are `num_codebooks` labels to predict per timestep instead of a single token per timestep. This materializes in the loss, which consists of the mean of the cross-entropy per codebook.

A few additional insights:
- Labels are expected to be padded to `max_length`.
- The first codebook channel is predicted without delay, but the further you go, the more delay there is (2nd codebook -> delayed by 1, 3rd codebook -> delayed by 2, etc.); a small sketch of this pattern follows below.
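As an illustration of that delay pattern, here is a small, self-contained sketch (not the model's own utility; `pad_token_id=0` is just a placeholder for the example):

```python
import torch

def apply_delay_pattern(codes: torch.LongTensor, pad_token_id: int) -> torch.LongTensor:
    """Shift codebook k to the right by k steps: codebook 0 has no delay, codebook 1 a delay of 1, etc."""
    num_codebooks, seq_len = codes.shape
    delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad_token_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        delayed[k, k : k + seq_len] = codes[k]
    return delayed

codes = torch.arange(1, 13).reshape(4, 3)  # 4 codebooks, 3 timesteps
print(apply_delay_pattern(codes, pad_token_id=0))
# tensor([[ 1,  2,  3,  0,  0,  0],
#         [ 0,  4,  5,  6,  0,  0],
#         [ 0,  0,  7,  8,  9,  0],
#         [ 0,  0,  0, 10, 11, 12]])
```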
cc @sanchit-gandhi and @amyeroberts