-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🚨 Add training compatibility for Musicgen-like models #29802
Changes from 88 commits
853d2c0
2ff2f3d
a3fa21f
4c02db4
b141703
2b19612
eae18da
2285db3
cb8f4c5
0ab4623
c1e196d
c8bf6c5
f015753
c8c5a4e
2cf5cfb
bce1aaf
0e944af
1a03cd9
fded84d
133e486
a70d0da
34c8270
fdd1743
b13cbcf
2bb0adb
d06b327
7842840
8e7c128
8e1bc88
61eb704
e761acc
96baf7d
357b416
ebe4cde
aacf7ee
3838361
a68c1a0
b174155
f9620b9
6b6d7cb
8c1d8f8
4eface6
2109479
3bfc793
9cd463a
9c4aee1
0fa0274
0535b57
87f4cf7
74661a4
eb786b0
2b8acd3
608d4f6
b36e802
3fd2839
9f15d02
48c2c3f
9a43be0
cf89389
bb69817
fc33efb
ba4d732
5166259
ff99457
755960a
8b9177f
ad26dc9
7595256
2576806
379d70b
9795c6f
b03b36d
1490150
b434f8a
ebeca43
604a4c8
7bda3c3
5863cf9
5ba6f3b
9c17c93
3f780ec
7ae5d49
319c668
3dee4fc
213b090
6a3706c
03b8bd6
99179fd
11eb376
6aa16df
ff8d32e
be75fad
98f07b8
371e154
52f2e6e
1f95a30
c559e32
56c1226
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -104,16 +104,15 @@ class MusicgenUnconditionalInput(ModelOutput): | |||||||||||||||||||||||||||||||||||||
guidance_scale: float = None | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right | ||||||||||||||||||||||||||||||||||||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
Shift input ids one token to the right. | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape) | ||||||||||||||||||||||||||||||||||||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() | ||||||||||||||||||||||||||||||||||||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() | ||||||||||||||||||||||||||||||||||||||
if decoder_start_token_id is None: | ||||||||||||||||||||||||||||||||||||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") | ||||||||||||||||||||||||||||||||||||||
shifted_input_ids[:, 0] = decoder_start_token_id | ||||||||||||||||||||||||||||||||||||||
shifted_input_ids[..., 0] = decoder_start_token_id | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if pad_token_id is None: | ||||||||||||||||||||||||||||||||||||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") | ||||||||||||||||||||||||||||||||||||||
|
@@ -909,6 +908,10 @@ def _init_weights(self, module): | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value | ||||||||||||||||||||||||||||||||||||||
of `inputs_embeds`. | ||||||||||||||||||||||||||||||||||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): | ||||||||||||||||||||||||||||||||||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set | ||||||||||||||||||||||||||||||||||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` | ||||||||||||||||||||||||||||||||||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` | ||||||||||||||||||||||||||||||||||||||
use_cache (`bool`, *optional*): | ||||||||||||||||||||||||||||||||||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | ||||||||||||||||||||||||||||||||||||||
`past_key_values`). | ||||||||||||||||||||||||||||||||||||||
|
@@ -1340,15 +1343,22 @@ def forward( | |||||||||||||||||||||||||||||||||||||
return_dict: Optional[bool] = None, | ||||||||||||||||||||||||||||||||||||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: | ||||||||||||||||||||||||||||||||||||||
r""" | ||||||||||||||||||||||||||||||||||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||||||||||||||||||||||||||||||||||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): | ||||||||||||||||||||||||||||||||||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set | ||||||||||||||||||||||||||||||||||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` | ||||||||||||||||||||||||||||||||||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` | ||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if (labels is not None) and (input_ids is None and inputs_embeds is None): | ||||||||||||||||||||||||||||||||||||||
input_ids = shift_tokens_right( | ||||||||||||||||||||||||||||||||||||||
labels.transpose(1, 2), | ||||||||||||||||||||||||||||||||||||||
ylacombe marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||
self.config.pad_token_id, | ||||||||||||||||||||||||||||||||||||||
self.config.bos_token_id, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
outputs = self.model( | ||||||||||||||||||||||||||||||||||||||
input_ids, | ||||||||||||||||||||||||||||||||||||||
attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||
|
@@ -1370,7 +1380,28 @@ def forward( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
loss = None | ||||||||||||||||||||||||||||||||||||||
if labels is not None: | ||||||||||||||||||||||||||||||||||||||
raise NotImplementedError("Training is not implemented for Musicgen.") | ||||||||||||||||||||||||||||||||||||||
# since encoder hidden states have been concatenated to the decoder hidden states, | ||||||||||||||||||||||||||||||||||||||
# we take the last timestamps corresponding to labels | ||||||||||||||||||||||||||||||||||||||
logits = lm_logits[:, :, -labels.shape[1] :] | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
loss_fct = CrossEntropyLoss() | ||||||||||||||||||||||||||||||||||||||
loss = torch.zeros([], device=self.device) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# per codebook cross-entropy | ||||||||||||||||||||||||||||||||||||||
# -100 labels are ignored | ||||||||||||||||||||||||||||||||||||||
labels = labels.masked_fill(labels == self.config.pad_token_id, -100) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
mask = labels != -100 | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# per codebook cross-entropy | ||||||||||||||||||||||||||||||||||||||
for codebook in range(self.config.num_codebooks): | ||||||||||||||||||||||||||||||||||||||
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) | ||||||||||||||||||||||||||||||||||||||
codebook_mask = mask[..., codebook].contiguous().view(-1) | ||||||||||||||||||||||||||||||||||||||
codebook_labels = labels[..., codebook].contiguous().view(-1) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask]) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The padded labels should be set to
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
loss = loss / self.config.num_codebooks | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reference from original AudioCraft code: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243 Interesting that they average over codebooks, and not a true average over all labels. Given the sequence length for music generation is large (1500 tokens), the difference is going to be negligible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add this link in a comment above for any unsuspecting future code reader to have context |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) | ||||||||||||||||||||||||||||||||||||||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) | ||||||||||||||||||||||||||||||||||||||
|
@@ -2234,8 +2265,9 @@ def forward( | |||||||||||||||||||||||||||||||||||||
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): | ||||||||||||||||||||||||||||||||||||||
# transpose to get (bsz, num_codebooks, seq_len) | ||||||||||||||||||||||||||||||||||||||
decoder_input_ids = shift_tokens_right( | ||||||||||||||||||||||||||||||||||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id | ||||||||||||||||||||||||||||||||||||||
labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
elif decoder_input_ids is None and decoder_inputs_embeds is None: | ||||||||||||||||||||||||||||||||||||||
|
@@ -2270,23 +2302,15 @@ def forward( | |||||||||||||||||||||||||||||||||||||
use_cache=use_cache, | ||||||||||||||||||||||||||||||||||||||
past_key_values=past_key_values, | ||||||||||||||||||||||||||||||||||||||
return_dict=return_dict, | ||||||||||||||||||||||||||||||||||||||
labels=labels, | ||||||||||||||||||||||||||||||||||||||
**kwargs_decoder, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
loss = None | ||||||||||||||||||||||||||||||||||||||
if labels is not None: | ||||||||||||||||||||||||||||||||||||||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0] | ||||||||||||||||||||||||||||||||||||||
loss_fct = CrossEntropyLoss() | ||||||||||||||||||||||||||||||||||||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if not return_dict: | ||||||||||||||||||||||||||||||||||||||
if loss is not None: | ||||||||||||||||||||||||||||||||||||||
return (loss,) + decoder_outputs + encoder_outputs | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
return decoder_outputs + encoder_outputs | ||||||||||||||||||||||||||||||||||||||
return decoder_outputs + encoder_outputs | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return Seq2SeqLMOutput( | ||||||||||||||||||||||||||||||||||||||
loss=loss, | ||||||||||||||||||||||||||||||||||||||
loss=decoder_outputs.loss, | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmmm, the problem with this is it's not backwards compatible - users are now going to get different values of loss than before There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I haven't seen any users using this or mentioning this tbh (besides it's totally wrong!). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes |
||||||||||||||||||||||||||||||||||||||
logits=decoder_outputs.logits, | ||||||||||||||||||||||||||||||||||||||
past_key_values=decoder_outputs.past_key_values, | ||||||||||||||||||||||||||||||||||||||
decoder_hidden_states=decoder_outputs.hidden_states, | ||||||||||||||||||||||||||||||||||||||
|
@@ -2524,7 +2548,9 @@ def _prepare_audio_encoder_kwargs_for_generation( | |||||||||||||||||||||||||||||||||||||
return model_kwargs | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) | ||||||||||||||||||||||||||||||||||||||
return shift_tokens_right( | ||||||||||||||||||||||||||||||||||||||
labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def resize_token_embeddings(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||
raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||
|
@@ -2533,6 +2559,16 @@ def resize_token_embeddings(self, *args, **kwargs): | |||||||||||||||||||||||||||||||||||||
" model.decoder.resize_token_embeddings(...))" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def freeze_encoders(self, freeze_text_encoder=True): | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should have a docstring as its a public method |
||||||||||||||||||||||||||||||||||||||
if freeze_text_encoder: | ||||||||||||||||||||||||||||||||||||||
for param in self.text_encoder.parameters(): | ||||||||||||||||||||||||||||||||||||||
param.requires_grad = False | ||||||||||||||||||||||||||||||||||||||
self.text_encoder._requires_grad = False | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
for param in self.audio_encoder.parameters(): | ||||||||||||||||||||||||||||||||||||||
param.requires_grad = False | ||||||||||||||||||||||||||||||||||||||
self.audio_encoder._requires_grad = False | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a fan of this structure. In general, we don't add It's tidier to split this up to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've split these up ! |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _maybe_initialize_input_ids_for_generation( | ||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||
inputs: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,16 +116,16 @@ class MusicgenMelodyOutputWithPast(ModelOutput): | |
encoder_hidden_states: Optional[torch.FloatTensor] = None | ||
|
||
|
||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right | ||
# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right | ||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): | ||
""" | ||
Shift input ids one token to the right. | ||
""" | ||
shifted_input_ids = input_ids.new_zeros(input_ids.shape) | ||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() | ||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() | ||
if decoder_start_token_id is None: | ||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") | ||
shifted_input_ids[:, 0] = decoder_start_token_id | ||
shifted_input_ids[..., 0] = decoder_start_token_id | ||
|
||
if pad_token_id is None: | ||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") | ||
|
@@ -864,7 +864,7 @@ def _init_weights(self, module): | |
|
||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value | ||
of `inputs_embeds`. | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): | ||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set | ||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` | ||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` | ||
|
@@ -1269,7 +1269,7 @@ def forward( | |
labels: Optional[torch.LongTensor] = None, | ||
) -> Union[Tuple, MusicgenMelodyOutputWithPast]: | ||
r""" | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): | ||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set | ||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` | ||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` | ||
|
@@ -1278,6 +1278,13 @@ def forward( | |
|
||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
if (labels is not None) and (input_ids is None and inputs_embeds is None): | ||
input_ids = shift_tokens_right( | ||
labels.transpose(1, 2), | ||
self.config.pad_token_id, | ||
self.config.bos_token_id, | ||
) | ||
|
||
outputs = self.model( | ||
input_ids, | ||
attention_mask=attention_mask, | ||
|
@@ -1298,7 +1305,28 @@ def forward( | |
|
||
loss = None | ||
if labels is not None: | ||
raise NotImplementedError("Training is not implemented for MusicgenMelody.") | ||
# since encoder hidden states have been concatenated to the decoder hidden states, | ||
# we take the last timestamps corresponding to labels | ||
logits = lm_logits[:, :, -labels.shape[1] :] | ||
|
||
loss_fct = CrossEntropyLoss() | ||
loss = torch.zeros([], device=self.device) | ||
|
||
# per codebook cross-entropy | ||
# -100 labels are ignored | ||
labels = labels.masked_fill(labels == self.config.pad_token_id, -100) | ||
|
||
mask = labels != -100 | ||
|
||
# per codebook cross-entropy | ||
for codebook in range(self.config.num_codebooks): | ||
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) | ||
codebook_mask = mask[..., codebook].contiguous().view(-1) | ||
codebook_labels = labels[..., codebook].contiguous().view(-1) | ||
|
||
loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask]) | ||
|
||
loss = loss / self.config.num_codebooks | ||
|
||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) | ||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) | ||
|
@@ -2155,8 +2183,9 @@ def forward( | |
encoder_hidden_states = audio_hidden_states | ||
|
||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): | ||
# transpose to get (bsz, num_codebooks, seq_len) | ||
decoder_input_ids = shift_tokens_right( | ||
labels, self.config.pad_token_id, self.config.decoder_start_token_id | ||
labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id | ||
) | ||
|
||
# Decode | ||
|
@@ -2170,23 +2199,15 @@ def forward( | |
use_cache=use_cache, | ||
past_key_values=past_key_values, | ||
return_dict=return_dict, | ||
labels=labels, | ||
**kwargs_decoder, | ||
) | ||
|
||
loss = None | ||
if labels is not None: | ||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0] | ||
loss_fct = CrossEntropyLoss() | ||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) | ||
|
||
if not return_dict: | ||
if loss is not None: | ||
return (loss,) + decoder_outputs + (encoder_hidden_states,) | ||
else: | ||
return decoder_outputs + (encoder_hidden_states,) | ||
return decoder_outputs + (encoder_hidden_states,) | ||
|
||
return MusicgenMelodyOutputWithPast( | ||
loss=loss, | ||
loss=decoder_outputs.loss, | ||
logits=decoder_outputs.logits, | ||
past_key_values=decoder_outputs.past_key_values, | ||
hidden_states=decoder_outputs.hidden_states, | ||
|
@@ -2397,7 +2418,9 @@ def _prepare_encoder_hidden_states_kwargs_for_generation( | |
return model_kwargs | ||
|
||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): | ||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) | ||
return shift_tokens_right( | ||
labels.transpose(1, 2), self.config.decoder.pad_token_id, self.config.decoder.bos_token_id | ||
) | ||
|
||
def resize_token_embeddings(self, *args, **kwargs): | ||
raise NotImplementedError( | ||
|
@@ -2428,6 +2451,16 @@ def _maybe_initialize_input_ids_for_generation( | |
break | ||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id | ||
|
||
def freeze_encoders(self, freeze_text_encoder=True): | ||
ylacombe marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here |
||
if freeze_text_encoder: | ||
for param in self.text_encoder.parameters(): | ||
param.requires_grad = False | ||
self.text_encoder._requires_grad = False | ||
|
||
for param in self.audio_encoder.parameters(): | ||
param.requires_grad = False | ||
self.audio_encoder._requires_grad = False | ||
|
||
@torch.no_grad() | ||
def generate( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
c.f. this docstring that explains how the
-100
padding token should be set