
🚨 Add training compatibility for Musicgen-like models #29802

Merged
merged 98 commits into huggingface:main on Apr 25, 2024

Conversation

ylacombe
Contributor

@ylacombe ylacombe commented Mar 22, 2024

This PR aims to add training compatibility for Musicgen and Musicgen Melody.

The main difference from classic cross-entropy is that there are num_codebooks labels to predict per timestep instead of a single token. This shows up in the loss, which consists of the mean of the per-codebook cross-entropies.
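A minimal sketch of that loss (illustrative only, not the exact transformers implementation; shapes and names are assumptions):

```python
import torch
import torch.nn.functional as F

batch, num_codebooks, seq_len, vocab = 2, 4, 10, 2048

# Hypothetical shapes: one logit distribution per codebook per timestep.
logits = torch.randn(batch, num_codebooks, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len, num_codebooks))

# Mean over codebooks of a standard per-codebook cross-entropy.
loss = torch.tensor(0.0)
for codebook in range(num_codebooks):
    codebook_logits = logits[:, codebook].reshape(-1, vocab)
    codebook_labels = labels[..., codebook].reshape(-1)
    loss = loss + F.cross_entropy(codebook_logits, codebook_labels, ignore_index=-100)
loss = loss / num_codebooks
```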

A few additional insights:

  • The models don't have an EOS token id, so they generate for max_length.
  • The model actually predicts codebooks in a delayed pattern.
    - The first codebook channel is predicted without delay, but the further you go, the more delay there is (2nd codebook -> delayed by 1, 3rd codebook -> delayed by 2, etc.)
  • Training scripts will be shared as well.
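The delay pattern described above can be sketched as follows (a toy illustration under assumed conventions; the real delay-pattern mask in transformers also handles special tokens):

```python
import torch

num_codebooks, seq_len = 4, 8
pad_token_id = -1  # placeholder value for illustration only
codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)

# Codebook k's tokens are shifted right by k timesteps, with a pad token
# filling the gap introduced at the start.
delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad_token_id)
for k in range(num_codebooks):
    delayed[k, k : k + seq_len] = codes[k]

# Codebook 0 starts at t=0, codebook 1 at t=1, and so on.
```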

cc @sanchit-gandhi and @amyeroberts

@ylacombe ylacombe marked this pull request as ready for review March 22, 2024 12:50
@ylacombe ylacombe changed the title [WIP] Add training compatibility for Musicgen-like models Add training compatibility for Musicgen-like models Mar 22, 2024
@ylacombe ylacombe requested a review from sanchit-gandhi March 22, 2024 13:07
@arjunsinghrathore

Hi! Is it possible to fine-tune the Musicgen model currently? If yes, is there something I should keep in mind? It would be really helpful if you could share your opinions. Thanks!

@LiuZH-19

Wonderful work! I'm currently attempting to fine-tune the Musicgen model using these codes, but I haven't succeeded yet. Is the model ready for fine-tuning, and are there specific aspects I should be aware of? Any training tips or guidance you could provide would be greatly appreciated!

Thank you so much!

@ylacombe
Contributor Author

Hey @arjunsinghrathore and @LiuZH-19, I'll likely release some fine-tuning code next week or the week after!
Out of curiosity, may I ask what type of data you have?
Thanks!

Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Fast tests look quite rigorous, but we can probably use the base one unless there's a specific embedding issue! Otherwise LGTM

src/transformers/models/musicgen/modeling_musicgen.py (outdated; resolved)
Comment on lines 1390 to 1402
            # per codebook cross-entropy
            # -100 labels are ignored
            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)

            mask = labels != -100

            # per codebook cross-entropy
            for codebook in range(self.config.num_codebooks):
                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
                codebook_mask = mask[..., codebook].contiguous().view(-1)
                codebook_labels = labels[..., codebook].contiguous().view(-1)

                loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
Contributor

The padded labels should be set to -100 outside of the modelling code, i.e. in the data collator. Then, we don't need any of this masking logic, since the CE loss masks out -100 values by default (see argument ignore_index)
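A sketch of what that collator-side replacement could look like (hypothetical helper, assuming already-padded label tensors of shape `(seq_len, num_codebooks)`); the model's loss can then rely on cross-entropy's default `ignore_index=-100`:

```python
import torch

pad_token_id = 2048  # hypothetical pad id for illustration

def collate_labels(label_batch):
    # label_batch: list of (seq_len, num_codebooks) tensors, already padded
    labels = torch.stack(label_batch)
    # Replace pad positions with -100 so cross-entropy ignores them by default.
    return labels.masked_fill(labels == pad_token_id, -100)

batch = [
    torch.tensor([[1, 2], [pad_token_id, 3]]),
    torch.tensor([[4, 5], [6, pad_token_id]]),
]
labels = collate_labels(batch)
```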

Suggested change

Before:

    # per codebook cross-entropy
    # -100 labels are ignored
    labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
    mask = labels != -100
    # per codebook cross-entropy
    for codebook in range(self.config.num_codebooks):
        codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
        codebook_mask = mask[..., codebook].contiguous().view(-1)
        codebook_labels = labels[..., codebook].contiguous().view(-1)
        loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])

After:

    # per codebook cross-entropy
    for codebook in range(self.config.num_codebooks):
        codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
        codebook_labels = labels[..., codebook].contiguous().view(-1)
        loss += loss_fct(codebook_logits, codebook_labels)


loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])

loss = loss / self.config.num_codebooks
Contributor

Reference from original AudioCraft code: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243

Interesting that they average over codebooks rather than taking a true average over all labels. Given that the sequence length for music generation is large (1500 tokens), the difference is going to be negligible.
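A toy illustration of the distinction (assumed shapes): averaging per-codebook losses weights each codebook equally, while a true average over all labels weights codebooks by their number of valid (non -100) labels, so the two differ when those counts are unequal.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 16
logits = torch.randn(2, 2, 6, vocab)  # (batch, codebooks, time, vocab)
labels = torch.randint(0, vocab, (2, 6, 2))
labels[:, 4:, 1] = -100  # codebook 1 has fewer valid labels

per_codebook, counts = [], []
for k in range(2):
    summed = F.cross_entropy(
        logits[:, k].reshape(-1, vocab),
        labels[..., k].reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )
    n = (labels[..., k] != -100).sum()
    per_codebook.append(summed / n)
    counts.append(n)

# Equal weight per codebook (what AudioCraft does):
mean_over_codebooks = sum(per_codebook) / 2
# Weight by number of valid labels (a "true" average over all labels):
true_mean = sum(l * n for l, n in zip(per_codebook, counts)) / sum(counts)
```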

Collaborator

Let's add this link in a comment above for any unsuspecting future code reader to have context

@@ -1340,15 +1343,22 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
Contributor

cf. this docstring, which explains how the -100 padding token should be set

Comment on lines 228 to 229
# Contrary to the original method, we don't unfreeze frozen parameters.
# Otherwise, it'll mess with the frozen sinusoidal embeddings.
Contributor

It'll change the weights of the pre-trained embeddings for sure, but the code should still run, right? Or is there an issue where the code won't run if we train the embeddings? Unless that's the case, I would just use the base check_training_gradient_checkpointing method for simplicity.

Contributor Author

@ylacombe ylacombe Apr 16, 2024

The comments are not really clear; I'll make them clearer:

  1. The ConditionalGeneration models contain the audio encoder, which is never used outside of .generate. If we don't freeze it, we'll run into issues, because those trainable weights will never see any gradients.
  2. The CausalModel's sinusoidal embeddings are frozen and should stay frozen (they shouldn't have been turned into a Parameter in the first place):
    self.weights.requires_grad = False
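A minimal sketch of the second point (toy module with hypothetical names, standing in for the real sinusoidal embedding class): even though the buffer was registered as a Parameter, disabling requires_grad keeps it out of training.

```python
import torch
import torch.nn as nn

class TinySinusoidalEmbedding(nn.Module):
    def __init__(self, num_positions: int, dim: int):
        super().__init__()
        # Fixed (non-learned) weights; random values stand in for sinusoids here.
        weights = torch.randn(num_positions, dim)
        self.weights = nn.Parameter(weights)
        # Keep the embedding frozen so optimizers and unfreezing logic skip it.
        self.weights.requires_grad = False

emb = TinySinusoidalEmbedding(16, 8)
trainable = [name for name, p in emb.named_parameters() if p.requires_grad]
```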

@ylacombe ylacombe requested a review from amyeroberts April 16, 2024 17:45
@ylacombe
Contributor Author

Hey @amyeroberts, gentle ping to ask for a review! Many thanks for your help!

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for adding this capability!

Mostly small comments. My only concern is backwards compatibility w.r.t. the loss for Musicgen.

        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (labels is not None) and (input_ids is None and inputs_embeds is None):
            input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
Collaborator

It's a bit funny to do this on the input_ids - normally it's only done on the decoder_input_ids.

Contributor Author

Well, it's a bit confusing, but input_ids in MusicgenForCausalLM actually corresponds to the audio input ids (i.e. the decoder input ids)!

Collaborator

Ah, ok! Makes sense :)
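For context, the shift applied to the labels can be sketched like this (assumed semantics of a right-shift helper, simplified to 2-D labels; the real Musicgen labels carry an extra codebook dimension):

```python
import torch

def shift_tokens_right_sketch(labels, pad_token_id, bos_token_id):
    # Shift the labels one position to the right and prepend the BOS token.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = bos_token_id
    # -100 padding in the labels must not leak into the model inputs.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[5, 6, 7, -100]])
inputs = shift_tokens_right_sketch(labels, pad_token_id=0, bos_token_id=1)
```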




        return Seq2SeqLMOutput(
-            loss=loss,
+            loss=decoder_outputs.loss,
Collaborator

hmmm, the problem with this is that it's not backwards compatible - users are now going to get different loss values than before

Contributor Author

I haven't seen any users using this or mentioning it, tbh (besides, it's totally wrong!).
How should we best handle this? Maybe by adding a breaking flag in the PR name?

Collaborator

If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes

Comment on lines 2553 to 2561
    def freeze_encoders(self, freeze_text_encoder=True):
        if freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            self.text_encoder._requires_grad = False

        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        self.audio_encoder._requires_grad = False
Collaborator

I'm not a fan of this structure. In general, we don't add freeze methods to our models and leave that to the user to handle - although I see audio models appear to be the exception!

It's tidier to split this up into freeze_text_encoder and freeze_audio_encoder and then just call them separately, or to add an additional freeze_audio_encoder argument.
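The suggested split could look roughly like this (a toy model with placeholder sub-modules, not the actual transformers code):

```python
import torch.nn as nn

class TinyConditionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder sub-models standing in for T5/Encodec encoders.
        self.text_encoder = nn.Linear(4, 4)
        self.audio_encoder = nn.Linear(4, 4)

    def freeze_text_encoder(self):
        """Disable gradients for the text encoder."""
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def freeze_audio_encoder(self):
        """Disable gradients for the audio encoder."""
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

model = TinyConditionalModel()
model.freeze_audio_encoder()  # each sub-model is now frozen independently
audio_frozen = all(not p.requires_grad for p in model.audio_encoder.parameters())
text_trainable = all(p.requires_grad for p in model.text_encoder.parameters())
```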

Contributor Author

I've split these up!

@@ -2428,6 +2442,16 @@ def _maybe_initialize_input_ids_for_generation(
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

def freeze_encoders(self, freeze_text_encoder=True):
Collaborator

Same comment here

tests/models/musicgen/test_modeling_musicgen.py (outdated; resolved)
tests/models/musicgen/test_modeling_musicgen.py (outdated; resolved)
@@ -2533,6 +2550,16 @@ def resize_token_embeddings(self, *args, **kwargs):
" model.decoder.resize_token_embeddings(...))"
)

def freeze_encoders(self, freeze_text_encoder=True):
Collaborator

Should have a docstring as it's a public method

@@ -2189,6 +2288,26 @@ def test_eager_matches_sdpa_generate(self):

self.assertTrue(torch.allclose(res_eager, res_sdpa))

def test_requires_grad_with_frozen_encoders(self):
Collaborator

❤️

@ylacombe
Contributor Author

Many thanks for the review @amyeroberts, I've changed the code according to your comments!
The only thing left to address is the breaking change to the loss computation. Let me know what you think.
Note that I don't believe many users actually used the loss computation as it was.

Collaborator

@amyeroberts amyeroberts left a comment

Awesome work - thanks for adding this feature!


@ylacombe ylacombe changed the title Add training compatibility for Musicgen-like models 🚨 Add training compatibility for Musicgen-like models Apr 25, 2024
@ylacombe ylacombe merged commit 90cb55b into huggingface:main Apr 25, 2024
23 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* first modeling code

* make repository

* still WIP

* update model

* add tests

* add latest change

* clean docstrings and copied from

* update docstrings md and readme

* correct chroma function

* correct copied from and remove unreleated test

* add doc to toctree

* correct imports

* add convert script to notdoctested

* Add suggestion from Sanchit

Co-authored-by: Sanchit Gandhi <[email protected]>

* correct get_uncoditional_inputs docstrings

* modify README according to SANCHIT feedback

* add chroma to audio utils

* clean librosa and torchaudio hard dependencies

* fix FE

* refactor audio decoder -> audio encoder for consistency with previous musicgen

* refactor conditional -> encoder

* modify sampling rate logics

* modify license at the beginning

* refactor all_self_attns->all_attentions

* remove ignore copy from causallm generate

* add copied from for from_sub_models

* fix make copies

* add warning if audio is truncated

* add copied from where relevant

* remove artefact

* fix convert script

* fix torchaudio and FE

* modify chroma method according to feedback-> better naming

* refactor input_values->input_features

* refactor input_values->input_features and fix import fe

* add input_features to docstrigs

* correct inputs_embeds logics

* remove dtype conversion

* refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation

* change warning for chroma length

* Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* change way to save wav, using soundfile

* correct docs and change to soundfile

* fix import

* fix init proj layers

* add draft training

* fix cross entropy

* clean loss computation

* fix labels

* remove line breaks from md

* fix issue with docstrings

* add FE suggestions

* improve is in logics and remove useless imports

* remove custom from_pretrained

* simplify docstring code

* add suggestions for modeling tests

* make style

* update converting script with sanity check

* remove encoder attention mask from conditional generation

* replace musicgen melody checkpoints with official orga

* rename ylacombe->facebook in checkpoints

* fix copies

* remove unecessary warning

* add shape in code docstrings

* add files to slow doc tests

* fix md bug and add md to not_tested

* make fix-copies

* fix hidden states test and batching

* update training code

* add training tests for melody

* add training for o.g musicgen

* fix copied from

* remove final todos

* make style

* fix style

* add suggestions from review

* add ref to the original loss computation code

* rename method + fix labels in tests

* make style

---------

Co-authored-by: Sanchit Gandhi <[email protected]>