Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True #29024

Merged: 5 commits, May 7, 2024

Changes from all commits
@@ -122,7 +122,8 @@ class ModelArguments:
         metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
     )
     suppress_tokens: List[int] = field(
-        default=None, metadata={
+        default=None,
+        metadata={
             "help": (
                 "Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
                 "Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
9 changes: 7 additions & 2 deletions src/transformers/models/albert/modeling_albert.py
@@ -877,8 +877,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return prediction_scores

     def _tie_weights(self) -> None:
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
Collaborator commented (on lines +883 to +885):
For my own understanding, could you provide a bit more information about the two cases? In particular:

  • Which of the two branches of the if/else is for accelerate compatibility?
  • How was the accelerate incompatibility caught, i.e. which tests?
  • Do we need both of these assignments? I'm assuming yes, but it's not clear to me why one works in one case and not the other.

Contributor Author @hackyon commented (Feb 15, 2024):
I copied this from the existing tie_weights() implementation.

Here is more context on where that was added: #19906.

I agree that the existing implementation of tie_weights() doesn't make sense. However, I added the if-statement to keep backward compatibility.

Here's my 2 cents on how this all works:

When loading the model, only one pointer out of the tied params is loaded with the correct values (when device=meta). Let's call this the "canonical" pointer. All the other tied params must copy from the canonical pointer, and it can't really be the other way around (at least for the device=meta case).

The canonical pointer is the key stored in the loaded state_dict. In the current logic for save_pretrained(), the canonical pointer is the weight key that's not in the _tied_weights_keys list. In this case, that makes self.bias the canonical pointer.

However, I imagine it's still possible for some (older?) pretrained models to use other logic for choosing the canonical pointer, so I added the if-statement for backwards compatibility.
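
A minimal sketch of that "canonical pointer" behavior (the Head class, its sizes, and its keys are hypothetical, not the PR's code): save_pretrained() drops the keys listed in _tied_weights_keys, so the checkpoint only ever contains the canonical key, and _tie_weights() must re-point the tied alias at it after loading.

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    # "decoder.bias" is the tied alias; "bias" is the canonical pointer.
    _tied_weights_keys = ["decoder.bias"]

    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 32)
        self.bias = nn.Parameter(torch.zeros(32))
        self.decoder.bias = self.bias  # tied: both names reference one Parameter

    def _tie_weights(self):
        self.decoder.bias = self.bias  # re-alias FROM the canonical pointer


head = Head()
# Mimic save_pretrained(): tied keys are not written to the checkpoint.
saved = {k: v for k, v in head.state_dict().items() if k not in head._tied_weights_keys}
print(sorted(saved))  # ['bias', 'decoder.weight'] -- 'decoder.bias' is never saved
```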

Contributor Author @hackyon commented:

One note about the case where device != meta: there, the direction of the pointer assignment doesn't matter as much. Whether you call self.bias = self.decoder.bias or self.decoder.bias = self.bias, once the params are loaded they show up in both, since the two names act like references to one underlying tensor.

When device == meta, the assignment doesn't create that shared reference; it just copies the sizes/metadata of the other tensor. When one of the tensors (the canonical one) is loaded, the other needs to copy the pointer reference explicitly; it is not automatically kept in sync.
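
A short sketch of that difference (assuming a recent PyTorch with meta-device support; the module and sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Real device: tying makes the two names share one storage, so an in-place
# update made through one name is visible through the other.
decoder = nn.Linear(8, 32)
bias = nn.Parameter(torch.zeros(32))
decoder.bias = bias
with torch.no_grad():
    bias.fill_(1.0)
print(decoder.bias[0].item())  # 1.0 -- loaded/updated values show up in both

# Meta device: tensors carry only shape/dtype and no storage, so an earlier
# tie cannot forward loaded values; the tie must be re-established (in the
# right direction) once the canonical tensor is materialized.
meta_decoder = nn.Linear(8, 32, device="meta")
print(meta_decoder.bias.is_meta)  # True -- nothing to share yet
```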



class AlbertSOPHead(nn.Module):

@@ -915,6 +919,7 @@ def get_output_embeddings(self) -> nn.Linear:

     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
+        self.predictions.bias = new_embeddings.bias
Collaborator commented (on lines 921 to +922):
That is a bit strange; we should only set the embedding here, not the bias.

Contributor Author @hackyon commented:
Yeah, I was caught a bit off guard here too.

The reason for this is resize_token_embeddings(), which uses these setters to install the new embeddings. We need to copy the new bias from the embeddings since its size may have changed (see the sketch after this exchange).

Collaborator commented:
Indeed. I just never saw one that had a bias, but you're right, it's entirely possible.
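
To make the resize_token_embeddings() point concrete, here is a sketch with a hypothetical MLMHead (not the library's actual classes): the resize path builds a new, larger decoder and installs it through set_output_embeddings(), so without the extra assignment the head's bias would keep the old vocabulary size.

```python
import torch.nn as nn


class MLMHead(nn.Module):
    def __init__(self, hidden=8, vocab=32):
        super().__init__()
        self.decoder = nn.Linear(hidden, vocab)
        self.bias = self.decoder.bias  # tied at construction time

    def set_output_embeddings(self, new_decoder):
        self.decoder = new_decoder
        self.bias = new_decoder.bias  # the extra line this PR adds to each setter


head = MLMHead()
head.set_output_embeddings(nn.Linear(8, 40))  # vocab resized: 32 -> 40
assert head.bias.shape[0] == head.decoder.out_features == 40  # stays in sync
```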


     def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
6 changes: 6 additions & 0 deletions src/transformers/models/bert/modeling_bert.py
@@ -778,6 +778,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1184,6 +1187,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1293,6 +1297,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1446,6 +1451,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
src/transformers/models/bert_generation/modeling_bert_generation.py

@@ -851,8 +851,12 @@ def forward(self, hidden_states):
         return logits

     def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias

@add_start_docstrings(

@@ -879,6 +883,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
6 changes: 6 additions & 0 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1704,6 +1704,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -2263,6 +2266,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -2375,6 +2379,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -2516,6 +2521,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
4 changes: 4 additions & 0 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -523,6 +523,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -818,6 +821,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     def forward(
         self,
4 changes: 4 additions & 0 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -1021,6 +1021,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1117,6 +1118,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
4 changes: 4 additions & 0 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -1120,6 +1120,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1217,6 +1218,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
6 changes: 6 additions & 0 deletions src/transformers/models/ernie/modeling_ernie.py
@@ -603,6 +603,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -990,6 +993,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1104,6 +1108,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1264,6 +1269,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
3 changes: 3 additions & 0 deletions src/transformers/models/flava/modeling_flava.py
@@ -1661,6 +1661,9 @@ def __init__(self, config, weight=None):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, x):
         x = self.transform(x)
         x = self.decoder(x)
12 changes: 9 additions & 3 deletions src/transformers/models/fnet/modeling_fnet.py
@@ -355,9 +355,13 @@ def forward(self, hidden_states):
         hidden_states = self.decoder(hidden_states)
         return hidden_states

-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias

class FNetOnlyMLMHead(nn.Module):
@@ -624,6 +628,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -718,6 +723,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
3 changes: 3 additions & 0 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -692,6 +692,9 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
         self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False)
         self.output_projection.weight = self.embed_tokens.weight

+    def _tie_weights(self):
+        self.embed_tokens.weight = self.output_projection.weight
+
     def forward(
         self,
         input_ids: torch.Tensor,
11 changes: 8 additions & 3 deletions src/transformers/models/ibert/modeling_ibert.py
@@ -868,6 +868,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(IBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -952,9 +953,13 @@ def forward(self, features, **kwargs):

         return x

-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias


@add_start_docstrings(
4 changes: 4 additions & 0 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -594,6 +594,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -873,6 +876,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
3 changes: 3 additions & 0 deletions src/transformers/models/markuplm/modeling_markuplm.py
@@ -316,6 +316,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
src/transformers/models/megatron_bert/modeling_megatron_bert.py

@@ -657,6 +657,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1021,6 +1024,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1130,6 +1134,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -1288,6 +1293,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
13 changes: 9 additions & 4 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -650,6 +650,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self) -> None:
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.transform(hidden_states)
         hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
@@ -938,8 +941,9 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder

-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
         # resize dense output embedings at first
@@ -1047,8 +1051,9 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder

-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
         # resize dense output embedings at first