
🚨 Add training compatibility for Musicgen-like models #29802

Merged Apr 25, 2024: 98 commits (changes shown from 88 commits)

Commits
853d2c0
first modeling code
ylacombe Jan 3, 2024
2ff2f3d
make repository
ylacombe Jan 3, 2024
a3fa21f
still WIP
ylacombe Jan 16, 2024
4c02db4
update model
ylacombe Jan 30, 2024
b141703
add tests
ylacombe Feb 1, 2024
2b19612
add latest change
ylacombe Feb 1, 2024
eae18da
clean docstrings and copied from
ylacombe Feb 1, 2024
2285db3
update docstrings md and readme
ylacombe Feb 1, 2024
cb8f4c5
correct chroma function
ylacombe Feb 1, 2024
0ab4623
Merge branch 'main' into add-musicgen-melody
ylacombe Feb 1, 2024
c1e196d
correct copied from and remove unreleated test
ylacombe Feb 1, 2024
c8bf6c5
add doc to toctree
ylacombe Feb 1, 2024
f015753
correct imports
ylacombe Feb 1, 2024
c8c5a4e
add convert script to notdoctested
ylacombe Feb 1, 2024
2cf5cfb
Add suggestion from Sanchit
ylacombe Feb 5, 2024
bce1aaf
Merge branch 'huggingface:main' into add-musicgen-melody
ylacombe Feb 5, 2024
0e944af
correct get_uncoditional_inputs docstrings
ylacombe Feb 5, 2024
1a03cd9
modify README according to SANCHIT feedback
ylacombe Feb 5, 2024
fded84d
add chroma to audio utils
ylacombe Feb 5, 2024
133e486
clean librosa and torchaudio hard dependencies
ylacombe Feb 5, 2024
a70d0da
fix FE
ylacombe Feb 5, 2024
34c8270
refactor audio decoder -> audio encoder for consistency with previous…
ylacombe Feb 5, 2024
fdd1743
refactor conditional -> encoder
ylacombe Feb 6, 2024
b13cbcf
modify sampling rate logics
ylacombe Feb 6, 2024
2bb0adb
modify license at the beginning
ylacombe Feb 6, 2024
d06b327
refactor all_self_attns->all_attentions
ylacombe Feb 6, 2024
7842840
remove ignore copy from causallm generate
ylacombe Feb 6, 2024
8e7c128
add copied from for from_sub_models
ylacombe Feb 6, 2024
8e1bc88
fix make copies
ylacombe Feb 6, 2024
61eb704
add warning if audio is truncated
ylacombe Feb 6, 2024
e761acc
add copied from where relevant
ylacombe Feb 6, 2024
96baf7d
remove artefact
ylacombe Feb 6, 2024
357b416
fix convert script
ylacombe Feb 6, 2024
ebe4cde
fix torchaudio and FE
ylacombe Feb 6, 2024
aacf7ee
modify chroma method according to feedback-> better naming
ylacombe Feb 6, 2024
3838361
refactor input_values->input_features
ylacombe Feb 6, 2024
a68c1a0
refactor input_values->input_features and fix import fe
ylacombe Feb 6, 2024
b174155
add input_features to docstrigs
ylacombe Feb 6, 2024
f9620b9
correct inputs_embeds logics
ylacombe Feb 6, 2024
6b6d7cb
remove dtype conversion
ylacombe Feb 6, 2024
8c1d8f8
refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_…
ylacombe Feb 6, 2024
4eface6
change warning for chroma length
ylacombe Feb 6, 2024
2109479
Update src/transformers/models/musicgen_melody/convert_musicgen_melod…
ylacombe Feb 6, 2024
3bfc793
change way to save wav, using soundfile
ylacombe Feb 6, 2024
9cd463a
correct docs and change to soundfile
ylacombe Feb 6, 2024
9c4aee1
fix import
ylacombe Feb 7, 2024
0fa0274
Merge branch 'huggingface:main' into add-musicgen-melody
ylacombe Feb 7, 2024
0535b57
fix init proj layers
ylacombe Feb 7, 2024
87f4cf7
Merge branch 'huggingface:main' into add-musicgen-melody
ylacombe Feb 7, 2024
74661a4
add draft training
ylacombe Feb 9, 2024
eb786b0
fix cross entropy
ylacombe Feb 9, 2024
2b8acd3
clean loss computation
ylacombe Feb 13, 2024
608d4f6
fix labels
ylacombe Feb 16, 2024
b36e802
remove line breaks from md
ylacombe Feb 19, 2024
3fd2839
fix issue with docstrings
ylacombe Feb 19, 2024
9f15d02
add FE suggestions
ylacombe Feb 19, 2024
48c2c3f
improve is in logics and remove useless imports
ylacombe Feb 19, 2024
9a43be0
remove custom from_pretrained
ylacombe Feb 19, 2024
cf89389
simplify docstring code
ylacombe Feb 19, 2024
bb69817
add suggestions for modeling tests
ylacombe Feb 19, 2024
fc33efb
make style
ylacombe Feb 19, 2024
ba4d732
update converting script with sanity check
ylacombe Feb 19, 2024
5166259
remove encoder attention mask from conditional generation
ylacombe Feb 19, 2024
ff99457
Merge pull request #1 from ylacombe/add-musicgen-melody
ylacombe Feb 19, 2024
755960a
Merge branch 'main' into add-musicgen-melody
ylacombe Feb 26, 2024
8b9177f
Merge branch 'main' into add-musicgen-melody
ylacombe Mar 4, 2024
ad26dc9
replace musicgen melody checkpoints with official orga
ylacombe Mar 4, 2024
7595256
rename ylacombe->facebook in checkpoints
ylacombe Mar 4, 2024
2576806
fix copies
ylacombe Mar 4, 2024
379d70b
remove unecessary warning
ylacombe Mar 4, 2024
9795c6f
add shape in code docstrings
ylacombe Mar 4, 2024
b03b36d
add files to slow doc tests
ylacombe Mar 4, 2024
1490150
Merge pull request #2 from ylacombe/add-musicgen-melody
ylacombe Mar 4, 2024
b434f8a
fix md bug and add md to not_tested
ylacombe Mar 5, 2024
ebeca43
Merge branch 'main' into add-musicgen-melody
ylacombe Mar 18, 2024
604a4c8
make fix-copies
ylacombe Mar 18, 2024
7bda3c3
Merge branch 'huggingface:main' into add-musicgen-melody
ylacombe Mar 18, 2024
5863cf9
fix hidden states test and batching
ylacombe Mar 18, 2024
5ba6f3b
Merge pull request #3 from ylacombe/add-musicgen-melody
ylacombe Mar 22, 2024
9c17c93
Merge branch 'main' into add-training-musicgen
ylacombe Mar 22, 2024
3f780ec
update training code
ylacombe Mar 22, 2024
7ae5d49
add training tests for melody
ylacombe Mar 22, 2024
319c668
add training for o.g musicgen
ylacombe Mar 22, 2024
3dee4fc
fix copied from
ylacombe Mar 22, 2024
213b090
remove final todos
ylacombe Mar 22, 2024
6a3706c
make style
ylacombe Mar 22, 2024
03b8bd6
Merge branch 'main' into add-training-musicgen
ylacombe Apr 2, 2024
99179fd
fix style
ylacombe Apr 2, 2024
11eb376
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 16, 2024
6aa16df
add suggestions from review
ylacombe Apr 16, 2024
ff8d32e
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 18, 2024
be75fad
add ref to the original loss computation code
ylacombe Apr 20, 2024
98f07b8
rename method + fix labels in tests
ylacombe Apr 20, 2024
371e154
make style
ylacombe Apr 20, 2024
52f2e6e
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 25, 2024
1f95a30
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 25, 2024
c559e32
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 25, 2024
56c1226
Merge branch 'huggingface:main' into add-training-musicgen
ylacombe Apr 25, 2024
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py

@@ -157,6 +157,8 @@
     ("mpt", "MptModel"),
    ("mra", "MraModel"),
    ("mt5", "MT5Model"),
+    ("musicgen", "MusicgenModel"),
+    ("musicgen_melody", "MusicgenMelodyModel"),
    ("mvp", "MvpModel"),
    ("nat", "NatModel"),
    ("nezha", "NezhaModel"),
74 changes: 55 additions & 19 deletions src/transformers/models/musicgen/modeling_musicgen.py

@@ -104,16 +104,15 @@ class MusicgenUnconditionalInput(ModelOutput):
guidance_scale: float = None


 # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
     Shift input ids one token to the right.
     """
     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
     if decoder_start_token_id is None:
         raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids[..., 0] = decoder_start_token_id
 
     if pad_token_id is None:
         raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
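The one-character change above (`[:, …]` to `[..., …]`) is what lets the shift work on 3-D label tensors of shape `(bsz, num_codebooks, seq_len)` as well as the original 2-D case. A minimal NumPy sketch of that last-axis behavior (illustrative only; the real function continues past the visible hunk, and torch's `labels.transpose(1, 2)` corresponds to `transpose(0, 2, 1)` here):

```python
import numpy as np

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Indexing on the last axis shifts every codebook row independently,
    # so the same code handles (bsz, seq_len) and (bsz, num_codebooks, seq_len).
    if decoder_start_token_id is None:
        raise ValueError("decoder_start_token_id must be set")
    if pad_token_id is None:
        raise ValueError("pad_token_id must be set")
    shifted = np.zeros_like(input_ids)
    shifted[..., 1:] = input_ids[..., :-1]
    shifted[..., 0] = decoder_start_token_id
    return shifted

# labels arrive as (bsz, seq_len, num_codebooks); the PR transposes the last
# two axes before shifting, i.e. labels.transpose(1, 2) in the diff
labels = np.array([[[5, 9], [6, 10], [7, 11], [8, 12]]])        # (1, 4, 2)
decoder_input_ids = shift_tokens_right(labels.transpose(0, 2, 1), 0, 2)
print(decoder_input_ids[0].tolist())  # [[2, 5, 6, 7], [2, 9, 10, 11]]
```

Each codebook row gets the start token prepended and its last token dropped, which is exactly what a 2-D shift would have done per sequence.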
@@ -909,6 +908,10 @@ def _init_weights(self, module):

         If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
         of `inputs_embeds`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         use_cache (`bool`, *optional*):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
@@ -1340,15 +1343,22 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

Review comment (Contributor): c.f. this docstring that explains how the -100 padding token should be set.

         Returns:
         """

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if (labels is not None) and (input_ids is None and inputs_embeds is None):
+            input_ids = shift_tokens_right(
+                labels.transpose(1, 2),
+                self.config.pad_token_id,
+                self.config.bos_token_id,
+            )

outputs = self.model(
input_ids,
attention_mask=attention_mask,
@@ -1370,7 +1380,28 @@

loss = None
if labels is not None:
-            raise NotImplementedError("Training is not implemented for Musicgen.")
+            # since encoder hidden states have been concatenated to the decoder hidden states,
+            # we take the last timestamps corresponding to labels
+            logits = lm_logits[:, :, -labels.shape[1] :]
+
+            loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
+
+            # per codebook cross-entropy
+            # -100 labels are ignored
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+            mask = labels != -100
+
+            # per codebook cross-entropy
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_mask = mask[..., codebook].contiguous().view(-1)
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+
+                loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+
+            loss = loss / self.config.num_codebooks

Review comment (Contributor): The padded labels should be set to -100 outside of the modelling code, i.e. in the data collator. Then we don't need any of this masking logic, since the CE loss masks out -100 values by default (see the `ignore_index` argument). Suggested change:

            # per codebook cross-entropy
            for codebook in range(self.config.num_codebooks):
                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
                codebook_labels = labels[..., codebook].contiguous().view(-1)
                loss += loss_fct(codebook_logits, codebook_labels)

Review comment (Contributor): Reference from the original AudioCraft code: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243

Interesting that they average over codebooks, and not a true average over all labels. Given the sequence length for music generation is large (1500 tokens), the difference is going to be negligible.

Review comment (Collaborator): Let's add this link in a comment above for any unsuspecting future code reader to have context.
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
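The merged loss averages one cross-entropy term per codebook rather than taking a single mean over all unmasked tokens, matching the AudioCraft reference the reviewer links. A NumPy sketch of that reduction with the `-100` masking written out explicitly (illustrative shapes and names, not the Transformers code itself):

```python
import numpy as np

def per_codebook_cross_entropy(logits, labels, ignore_index=-100):
    """logits: (bsz, num_codebooks, seq_len, vocab); labels: (bsz, seq_len, num_codebooks)."""
    num_codebooks = logits.shape[1]
    total = 0.0
    for cb in range(num_codebooks):
        cb_logits = logits[:, cb].reshape(-1, logits.shape[-1])  # (bsz*seq_len, vocab)
        cb_labels = labels[..., cb].reshape(-1)                  # (bsz*seq_len,)
        keep = cb_labels != ignore_index         # what CrossEntropyLoss's ignore_index does
        z = cb_logits[keep] - cb_logits[keep].max(axis=-1, keepdims=True)
        log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
        total += -log_probs[np.arange(keep.sum()), cb_labels[keep]].mean()
    # average over codebooks, not over all labels: a slightly different weighting
    # whenever codebooks have unequal numbers of unmasked positions
    return total / num_codebooks

bsz, num_codebooks, seq_len, vocab = 2, 4, 5, 8
logits = np.zeros((bsz, num_codebooks, seq_len, vocab))   # uniform predictions
labels = np.random.randint(0, vocab, size=(bsz, seq_len, num_codebooks))
labels[0, 0, :] = -100                                    # a padded position
print(round(per_codebook_cross_entropy(logits, labels), 4))  # log(8) ≈ 2.0794
```

With uniform logits the loss is exactly `log(vocab)` per codebook regardless of the labels, which makes the reduction easy to sanity-check.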
@@ -2234,8 +2265,9 @@ def forward(
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            # transpose to get (bsz, num_codebooks, seq_len)
             decoder_input_ids = shift_tokens_right(
-                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
             )

elif decoder_input_ids is None and decoder_inputs_embeds is None:
@@ -2270,23 +2302,15 @@ def forward(
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            labels=labels,
             **kwargs_decoder,
         )
 
-        loss = None
-        if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
         if not return_dict:
-            if loss is not None:
-                return (loss,) + decoder_outputs + encoder_outputs
-            else:
-                return decoder_outputs + encoder_outputs
+            return decoder_outputs + encoder_outputs

         return Seq2SeqLMOutput(
-            loss=loss,
+            loss=decoder_outputs.loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,

Review comment (Collaborator): hmmm, the problem with this is it's not backwards compatible - users are now going to get different values of loss than before.

Reply (Contributor, author): I haven't seen any users using this or mentioning this, tbh (besides, it's totally wrong!). How should we best handle this? Maybe by adding a breaking flag in the PR name?

Reply (Collaborator): If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes.
@@ -2524,7 +2548,9 @@ def _prepare_audio_encoder_kwargs_for_generation(
return model_kwargs

     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+        return shift_tokens_right(
+            labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
+        )

def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
@@ -2533,6 +2559,16 @@ def resize_token_embeddings(self, *args, **kwargs):
" model.decoder.resize_token_embeddings(...))"
)

+    def freeze_encoders(self, freeze_text_encoder=True):

Review comment (Collaborator): Should have a docstring as it's a public method.

+        if freeze_text_encoder:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+            self.text_encoder._requires_grad = False
+
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False

Review comment (Collaborator): I'm not a fan of this structure. In general, we don't add freeze methods to our models and leave that to the user to handle - although I see audio models appear to be the exception! It's tidier to split this up into freeze_text_encoder and freeze_audio_encoder and then just call them separately, or add an additional freeze_audio_encoder argument.

Reply (Contributor, author): I've split these up!
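The split the reviewer asks for — separate `freeze_text_encoder` and `freeze_audio_encoder` methods instead of one flag-controlled `freeze_encoders` — can be sketched with a stand-in module interface (the class and parameter structure below are illustrative, not the merged Transformers API):

```python
class DummyEncoder:
    """Stand-in for an nn.Module: just exposes parameters() with requires_grad flags."""
    def __init__(self, num_params):
        self.params = [type("P", (), {"requires_grad": True})() for _ in range(num_params)]
        self._requires_grad = True

    def parameters(self):
        return self.params


class MusicgenLikeModel:
    def __init__(self):
        self.text_encoder = DummyEncoder(3)
        self.audio_encoder = DummyEncoder(2)

    def _freeze(self, module):
        # shared helper: stop gradient updates for every parameter of the module
        for param in module.parameters():
            param.requires_grad = False
        module._requires_grad = False

    def freeze_text_encoder(self):
        """Disable gradient updates for the text encoder."""
        self._freeze(self.text_encoder)

    def freeze_audio_encoder(self):
        """Disable gradient updates for the audio encoder."""
        self._freeze(self.audio_encoder)


model = MusicgenLikeModel()
model.freeze_audio_encoder()  # freeze only the audio side
print(all(not p.requires_grad for p in model.audio_encoder.parameters()))  # True
print(all(p.requires_grad for p in model.text_encoder.parameters()))       # True
```

Two explicit methods make each freeze visible at the call site, instead of a boolean argument that freezes the audio encoder unconditionally either way.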


def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
71 changes: 52 additions & 19 deletions src/transformers/models/musicgen_melody/modeling_musicgen_melody.py

@@ -116,16 +116,16 @@ class MusicgenMelodyOutputWithPast(ModelOutput):
encoder_hidden_states: Optional[torch.FloatTensor] = None


-# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
+# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
     Shift input ids one token to the right.
     """
     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
     if decoder_start_token_id is None:
         raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids[..., 0] = decoder_start_token_id
 
     if pad_token_id is None:
         raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
@@ -864,7 +864,7 @@ def _init_weights(self, module):

If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
@@ -1269,7 +1269,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MusicgenMelodyOutputWithPast]:
r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
@@ -1278,6 +1278,13 @@

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if (labels is not None) and (input_ids is None and inputs_embeds is None):
+            input_ids = shift_tokens_right(
+                labels.transpose(1, 2),
+                self.config.pad_token_id,
+                self.config.bos_token_id,
+            )

outputs = self.model(
input_ids,
attention_mask=attention_mask,
@@ -1298,7 +1305,28 @@

loss = None
if labels is not None:
-            raise NotImplementedError("Training is not implemented for MusicgenMelody.")
+            # since encoder hidden states have been concatenated to the decoder hidden states,
+            # we take the last timestamps corresponding to labels
+            logits = lm_logits[:, :, -labels.shape[1] :]
+
+            loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
+
+            # per codebook cross-entropy
+            # -100 labels are ignored
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+            mask = labels != -100
+
+            # per codebook cross-entropy
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_mask = mask[..., codebook].contiguous().view(-1)
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+
+                loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+
+            loss = loss / self.config.num_codebooks

# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
@@ -2155,8 +2183,9 @@ def forward(
encoder_hidden_states = audio_hidden_states

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            # transpose to get (bsz, num_codebooks, seq_len)
             decoder_input_ids = shift_tokens_right(
-                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
             )

# Decode
@@ -2170,23 +2199,15 @@
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            labels=labels,
             **kwargs_decoder,
         )
 
-        loss = None
-        if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
         if not return_dict:
-            if loss is not None:
-                return (loss,) + decoder_outputs + (encoder_hidden_states,)
-            else:
-                return decoder_outputs + (encoder_hidden_states,)
+            return decoder_outputs + (encoder_hidden_states,)

return MusicgenMelodyOutputWithPast(
-            loss=loss,
+            loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
@@ -2397,7 +2418,9 @@ def _prepare_encoder_hidden_states_kwargs_for_generation(
return model_kwargs

     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+        return shift_tokens_right(
+            labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
+        )

def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
@@ -2428,6 +2451,16 @@
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

+    def freeze_encoders(self, freeze_text_encoder=True):

Review comment (Collaborator): Same comment here.

+        if freeze_text_encoder:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+            self.text_encoder._requires_grad = False
+
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False

@torch.no_grad()
def generate(
self,